| | |
| |
|
| | #include <numeric> |
| | #include <ostream> |
| | #include <variant> |
| |
|
| | #include <nanobind/nanobind.h> |
| | #include <nanobind/stl/optional.h> |
| | #include <nanobind/stl/pair.h> |
| | #include <nanobind/stl/string.h> |
| | #include <nanobind/stl/tuple.h> |
| | #include <nanobind/stl/variant.h> |
| | #include <nanobind/stl/vector.h> |
| |
|
| | #include "mlx/einsum.h" |
| | #include "mlx/ops.h" |
| | #include "mlx/utils.h" |
| | #include "python/src/load.h" |
| | #include "python/src/small_vector.h" |
| | #include "python/src/utils.h" |
| |
|
| | namespace mx = mlx::core; |
| | namespace nb = nanobind; |
| | using namespace nb::literals; |
| |
|
| | using Scalar = std::variant<bool, int, double>; |
| |
|
| | mx::Dtype scalar_to_dtype(Scalar s) { |
| | if (std::holds_alternative<int>(s)) { |
| | return mx::int32; |
| | } else if (std::holds_alternative<double>(s)) { |
| | return mx::float32; |
| | } else { |
| | return mx::bool_; |
| | } |
| | } |
| |
|
| | double scalar_to_double(Scalar s) { |
| | if (auto pv = std::get_if<int>(&s); pv) { |
| | return static_cast<double>(*pv); |
| | } else if (auto pv = std::get_if<double>(&s); pv) { |
| | return *pv; |
| | } else { |
| | return static_cast<double>(std::get<bool>(s)); |
| | } |
| | } |
| |
|
| | void init_ops(nb::module_& m) { |
// reshape: bound directly to mx::reshape. The input array is positional-only
// (unnamed nb::arg, `/` in the signature); `stream` is keyword-only and
// defaults to None.
m.def(
    "reshape",
    &mx::reshape,
    nb::arg(),
    "shape"_a,
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig("def reshape(a: array, /, shape: Sequence[int], *, stream: "
            "Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      Reshape an array while preserving the size.

      Args:
          a (array): Input array.
          shape (tuple(int)): New shape.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

      Returns:
          array: The reshaped array.
    )pbdoc");
| | m.def( |
| | "flatten", |
| | [](const mx::array& a, |
| | int start_axis, |
| | int end_axis, |
| | const mx::StreamOrDevice& s) { |
| | return mx::flatten(a, start_axis, end_axis); |
| | }, |
| | nb::arg(), |
| | "start_axis"_a = 0, |
| | "end_axis"_a = -1, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig("def flatten(a: array, /, start_axis: int = 0, end_axis: int = " |
| | "-1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Flatten an array. |
| | |
| | The axes flattened will be between ``start_axis`` and ``end_axis``, |
| | inclusive. Negative axes are supported. After converting negative axis to |
| | positive, axes outside the valid range will be clamped to a valid value, |
| | ``start_axis`` to ``0`` and ``end_axis`` to ``ndim - 1``. |
| | |
| | Args: |
| | a (array): Input array. |
| | start_axis (int, optional): The first dimension to flatten. Defaults to ``0``. |
| | end_axis (int, optional): The last dimension to flatten. Defaults to ``-1``. |
| | stream (Stream, optional): Stream or device. Defaults to ``None`` |
| | in which case the default stream of the default device is used. |
| | |
| | Returns: |
| | array: The flattened array. |
| | |
| | Example: |
| | >>> a = mx.array([[1, 2], [3, 4]]) |
| | >>> mx.flatten(a) |
| | array([1, 2, 3, 4], dtype=int32) |
| | >>> |
| | >>> mx.flatten(a, start_axis=0, end_axis=-1) |
| | array([1, 2, 3, 4], dtype=int32) |
| | )pbdoc"); |
// unflatten: bound directly to mx::unflatten. The input array is
// positional-only; `axis` and `shape` are required; `stream` is keyword-only.
m.def(
    "unflatten",
    &mx::unflatten,
    nb::arg(),
    "axis"_a,
    "shape"_a,
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      Unflatten an axis of an array to a shape.

      Args:
          a (array): Input array.
          axis (int): The axis to unflatten.
          shape (tuple(int)): The shape to unflatten to. At most one
            entry can be ``-1`` in which case the corresponding size will be
            inferred.
          stream (Stream, optional): Stream or device. Defaults to ``None``
            in which case the default stream of the default device is used.

      Returns:
          array: The unflattened array.

      Example:
          >>> a = mx.array([1, 2, 3, 4])
          >>> mx.unflatten(a, 0, (2, -1))
          array([[1, 2], [3, 4]], dtype=int32)
    )pbdoc");
| | m.def( |
| | "squeeze", |
| | [](const mx::array& a, const IntOrVec& v, const mx::StreamOrDevice& s) { |
| | if (std::holds_alternative<std::monostate>(v)) { |
| | return mx::squeeze(a, s); |
| | } else if (auto pv = std::get_if<int>(&v); pv) { |
| | return mx::squeeze(a, *pv, s); |
| | } else { |
| | return mx::squeeze(a, std::get<std::vector<int>>(v), s); |
| | } |
| | }, |
| | nb::arg(), |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def squeeze(a: array, /, axis: Union[None, int, Sequence[int]] = " |
| | "None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Remove length one axes from an array. |
| | |
| | Args: |
| | a (array): Input array. |
| | axis (int or tuple(int), optional): Axes to remove. Defaults |
| | to ``None`` in which case all size one axes are removed. |
| | |
| | Returns: |
| | array: The output array with size one axes removed. |
| | )pbdoc"); |
| | m.def( |
| | "expand_dims", |
| | [](const mx::array& a, |
| | const std::variant<int, std::vector<int>>& v, |
| | mx::StreamOrDevice s) { |
| | if (auto pv = std::get_if<int>(&v); pv) { |
| | return mx::expand_dims(a, *pv, s); |
| | } else { |
| | return mx::expand_dims(a, std::get<std::vector<int>>(v), s); |
| | } |
| | }, |
| | nb::arg(), |
| | "axis"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig("def expand_dims(a: array, /, axis: Union[int, Sequence[int]], " |
| | "*, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Add a size one dimension at the given axis. |
| | |
| | Args: |
| | a (array): Input array. |
| | axes (int or tuple(int)): The index of the inserted dimensions. |
| | |
| | Returns: |
| | array: The array with inserted dimensions. |
| | )pbdoc"); |
| | m.def( |
| | "abs", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::abs(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def abs(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise absolute value. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The absolute value of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "sign", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::sign(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def sign(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise sign. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The sign of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "negative", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::negative(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def negative(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise negation. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The negative of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "add", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::add(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def add(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise addition. |
| | |
| | Add two arrays with numpy-style broadcasting semantics. Either or both input arrays |
| | can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The sum of ``a`` and ``b``. |
| | )pbdoc"); |
| | m.def( |
| | "subtract", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::subtract(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def subtract(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise subtraction. |
| | |
| | Subtract one array from another with numpy-style broadcasting semantics. Either or both |
| | input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The difference ``a - b``. |
| | )pbdoc"); |
| | m.def( |
| | "multiply", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::multiply(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def multiply(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise multiplication. |
| | |
| | Multiply two arrays with numpy-style broadcasting semantics. Either or both |
| | input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The multiplication ``a * b``. |
| | )pbdoc"); |
| | m.def( |
| | "divide", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::divide(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise division. |
| | |
| | Divide two arrays with numpy-style broadcasting semantics. Either or both |
| | input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The quotient ``a / b``. |
| | )pbdoc"); |
| | m.def( |
| | "divmod", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::divmod(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def divmod(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise quotient and remainder. |
| | |
| | The fuction ``divmod(a, b)`` is equivalent to but faster than |
| | ``(a // b, a % b)``. The function uses numpy-style broadcasting |
| | semantics. Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | tuple(array, array): The quotient ``a // b`` and remainder ``a % b``. |
| | )pbdoc"); |
| | m.def( |
| | "floor_divide", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::floor_divide(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def floor_divide(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise integer division. |
| | |
| | If either array is a floating point type then it is equivalent to |
| | calling :func:`floor` after :func:`divide`. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The quotient ``a // b``. |
| | )pbdoc"); |
| | m.def( |
| | "remainder", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::remainder(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def remainder(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise remainder of division. |
| | |
| | Computes the remainder of dividing a with b with numpy-style |
| | broadcasting semantics. Either or both input arrays can also be |
| | scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The remainder of ``a // b``. |
| | )pbdoc"); |
| | m.def( |
| | "equal", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::equal(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise equality. |
| | |
| | Equality comparison on two arrays with numpy-style broadcasting semantics. |
| | Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The element-wise comparison ``a == b``. |
| | )pbdoc"); |
| | m.def( |
| | "not_equal", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::not_equal(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def not_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise not equal. |
| | |
| | Not equal comparison on two arrays with numpy-style broadcasting semantics. |
| | Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The element-wise comparison ``a != b``. |
| | )pbdoc"); |
| | m.def( |
| | "less", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::less(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def less(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise less than. |
| | |
| | Strict less than on two arrays with numpy-style broadcasting semantics. |
| | Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The element-wise comparison ``a < b``. |
| | )pbdoc"); |
| | m.def( |
| | "less_equal", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::less_equal(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def less_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise less than or equal. |
| | |
| | Less than or equal on two arrays with numpy-style broadcasting semantics. |
| | Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The element-wise comparison ``a <= b``. |
| | )pbdoc"); |
| | m.def( |
| | "greater", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::greater(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def greater(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise greater than. |
| | |
| | Strict greater than on two arrays with numpy-style broadcasting semantics. |
| | Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The element-wise comparison ``a > b``. |
| | )pbdoc"); |
| | m.def( |
| | "greater_equal", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::greater_equal(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def greater_equal(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise greater or equal. |
| | |
| | Greater than or equal on two arrays with numpy-style broadcasting semantics. |
| | Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The element-wise comparison ``a >= b``. |
| | )pbdoc"); |
| | m.def( |
| | "array_equal", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | bool equal_nan, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::array_equal(a, b, equal_nan, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "equal_nan"_a = false, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def array_equal(a: Union[scalar, array], b: Union[scalar, array], equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Array equality check. |
| | |
| | Compare two arrays for equality. Returns ``True`` if and only if the arrays |
| | have the same shape and their values are equal. The arrays need not have |
| | the same type to be considered equal. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | equal_nan (bool): If ``True``, NaNs are considered equal. |
| | Defaults to ``False``. |
| | |
| | Returns: |
| | array: A scalar boolean array. |
| | )pbdoc"); |
| | m.def( |
| | "matmul", |
| | &mx::matmul, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def matmul(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Matrix multiplication. |
| | |
| | Perform the (possibly batched) matrix multiplication of two arrays. This function supports |
| | broadcasting for arrays with more than two dimensions. |
| | |
| | - If the first array is 1-D then a 1 is prepended to its shape to make it |
| | a matrix. Similarly if the second array is 1-D then a 1 is appended to its |
| | shape to make it a matrix. In either case the singleton dimension is removed |
| | from the result. |
| | - A batched matrix multiplication is performed if the arrays have more than |
| | 2 dimensions. The matrix dimensions for the matrix product are the last |
| | two dimensions of each input. |
| | - All but the last two dimensions of each input are broadcast with one another using |
| | standard numpy-style broadcasting semantics. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The matrix product of ``a`` and ``b``. |
| | )pbdoc"); |
| | m.def( |
| | "square", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::square(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def square(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise square. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The square of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "sqrt", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::sqrt(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def sqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise square root. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The square root of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "rsqrt", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::rsqrt(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def rsqrt(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise reciprocal and square root. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: One over the square root of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "reciprocal", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::reciprocal(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def reciprocal(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise reciprocal. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The reciprocal of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "logical_not", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::logical_not(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def logical_not(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise logical not. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The boolean array containing the logical not of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "logical_and", |
| | [](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) { |
| | return mx::logical_and(to_array(a), to_array(b), s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def logical_and(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise logical and. |
| | |
| | Args: |
| | a (array): First input array or scalar. |
| | b (array): Second input array or scalar. |
| | |
| | Returns: |
| | array: The boolean array containing the logical and of ``a`` and ``b``. |
| | )pbdoc"); |
| |
|
| | m.def( |
| | "logical_or", |
| | [](const ScalarOrArray& a, const ScalarOrArray& b, mx::StreamOrDevice s) { |
| | return mx::logical_or(to_array(a), to_array(b), s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def logical_or(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise logical or. |
| | |
| | Args: |
| | a (array): First input array or scalar. |
| | b (array): Second input array or scalar. |
| | |
| | Returns: |
| | array: The boolean array containing the logical or of ``a`` and ``b``. |
| | )pbdoc"); |
| | m.def( |
| | "logaddexp", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::logaddexp(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def logaddexp(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise log-add-exp. |
| | |
| | This is a numerically stable log-add-exp of two arrays with numpy-style |
| | broadcasting semantics. Either or both input arrays can also be scalars. |
| | |
| | The computation is is a numerically stable version of ``log(exp(a) + exp(b))``. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The log-add-exp of ``a`` and ``b``. |
| | )pbdoc"); |
| | m.def( |
| | "exp", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::exp(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def exp(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise exponential. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The exponential of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "expm1", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::expm1(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def expm1(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise exponential minus 1. |
| | |
| | Computes ``exp(x) - 1`` with greater precision for small ``x``. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The expm1 of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "erf", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::erf(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def erf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise error function. |
| | |
| | .. math:: |
| | \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The error function of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "erfinv", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::erfinv(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def erfinv(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise inverse of :func:`erf`. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The inverse error function of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "sin", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::sin(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def sin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise sine. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The sine of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "cos", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::cos(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def cos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise cosine. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The cosine of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "tan", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::tan(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def tan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise tangent. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The tangent of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "arcsin", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::arcsin(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def arcsin(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise inverse sine. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The inverse sine of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "arccos", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::arccos(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def arccos(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise inverse cosine. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The inverse cosine of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "arctan", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::arctan(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def arctan(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise inverse tangent. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The inverse tangent of ``a``. |
| | )pbdoc"); |
// arctan2: bound directly to mx::arctan2 (no to_array scalar conversion in
// this binding); both inputs are positional-only and `stream` keyword-only.
m.def(
    "arctan2",
    &mx::arctan2,
    nb::arg(),
    nb::arg(),
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def arctan2(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      Element-wise inverse tangent of the ratio of two arrays.

      Args:
          a (array): Input array.
          b (array): Input array.

      Returns:
          array: The inverse tangent of the ratio of ``a`` and ``b``.
    )pbdoc");
| | m.def( |
| | "sinh", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::sinh(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def sinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise hyperbolic sine. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The hyperbolic sine of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "cosh", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::cosh(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def cosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise hyperbolic cosine. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The hyperbolic cosine of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "tanh", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::tanh(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def tanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise hyperbolic tangent. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The hyperbolic tangent of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "arcsinh", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::arcsinh(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def arcsinh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise inverse hyperbolic sine. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The inverse hyperbolic sine of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "arccosh", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::arccosh(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def arccosh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise inverse hyperbolic cosine. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The inverse hyperbolic cosine of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "arctanh", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::arctanh(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def arctanh(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise inverse hyperbolic tangent. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The inverse hyperbolic tangent of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "degrees", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::degrees(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def degrees(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Convert angles from radians to degrees. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The angles in degrees. |
| | )pbdoc"); |
| | m.def( |
| | "radians", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::radians(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def radians(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Convert angles from degrees to radians. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The angles in radians. |
| | )pbdoc"); |
| | m.def( |
| | "log", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::log(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def log(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise natural logarithm. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The natural logarithm of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "log2", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::log2(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def log2(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise base-2 logarithm. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The base-2 logarithm of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "log10", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::log10(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def log10(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise base-10 logarithm. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The base-10 logarithm of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "log1p", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::log1p(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def log1p(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise natural log of one plus the array. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The natural logarithm of one plus ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "stop_gradient", |
| | &mx::stop_gradient, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def stop_gradient(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Stop gradients from being computed. |
| | |
| | The operation is the identity but it prevents gradients from flowing |
| | through the array. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: |
| | The unchanged input ``a`` but without gradient flowing |
| | through it. |
| | )pbdoc"); |
| | m.def( |
| | "sigmoid", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::sigmoid(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def sigmoid(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise logistic sigmoid. |
| | |
| | The logistic sigmoid function is: |
| | |
| | .. math:: |
| | \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}} |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The logistic sigmoid of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "power", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::power(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def power(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise power operation. |
| | |
| | Raise the elements of a to the powers in elements of b with numpy-style |
| | broadcasting semantics. Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: Bases of ``a`` raised to powers in ``b``. |
| | )pbdoc"); |
| | m.def( |
| | "arange", |
| | [](Scalar start, |
| | Scalar stop, |
| | const std::optional<Scalar>& step, |
| | const std::optional<mx::Dtype>& dtype_, |
| | mx::StreamOrDevice s) { |
| | |
| | mx::Dtype dtype = dtype_ |
| | ? *dtype_ |
| | : mx::promote_types( |
| | scalar_to_dtype(start), |
| | step ? mx::promote_types( |
| | scalar_to_dtype(stop), scalar_to_dtype(*step)) |
| | : scalar_to_dtype(stop)); |
| | return mx::arange( |
| | scalar_to_double(start), |
| | scalar_to_double(stop), |
| | step ? scalar_to_double(*step) : 1.0, |
| | dtype, |
| | s); |
| | }, |
| | "start"_a.noconvert(), |
| | "stop"_a.noconvert(), |
| | "step"_a.noconvert() = nb::none(), |
| | "dtype"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def arange(start : Union[int, float], stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Generates ranges of numbers. |
| | |
| | Generate numbers in the half-open interval ``[start, stop)`` in |
| | increments of ``step``. |
| | |
| | Args: |
| | start (float or int, optional): Starting value which defaults to ``0``. |
| | stop (float or int): Stopping value. |
| | step (float or int, optional): Increment which defaults to ``1``. |
| | dtype (Dtype, optional): Specifies the data type of the output. If unspecified will default to ``float32`` if any of ``start``, ``stop``, or ``step`` are ``float``. Otherwise will default to ``int32``. |
| | |
| | Returns: |
| | array: The range of values. |
| | |
| | Note: |
| | Following the Numpy convention the actual increment used to |
| | generate numbers is ``dtype(start + step) - dtype(start)``. |
| | This can lead to unexpected results for example if `start + step` |
| | is a fractional value and the `dtype` is integral. |
| | )pbdoc"); |
| | m.def( |
| | "arange", |
| | [](Scalar stop, |
| | const std::optional<Scalar>& step, |
| | const std::optional<mx::Dtype>& dtype_, |
| | mx::StreamOrDevice s) { |
| | mx::Dtype dtype = dtype_ ? *dtype_ |
| | : step |
| | ? mx::promote_types(scalar_to_dtype(stop), scalar_to_dtype(*step)) |
| | : scalar_to_dtype(stop); |
| | return mx::arange( |
| | 0.0, |
| | scalar_to_double(stop), |
| | step ? scalar_to_double(*step) : 1.0, |
| | dtype, |
| | s); |
| | }, |
| | "stop"_a.noconvert(), |
| | "step"_a.noconvert() = nb::none(), |
| | "dtype"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array")); |
| | m.def( |
| | "linspace", |
| | [](Scalar start, |
| | Scalar stop, |
| | int num, |
| | std::optional<mx::Dtype> dtype, |
| | mx::StreamOrDevice s) { |
| | return mx::linspace( |
| | scalar_to_double(start), |
| | scalar_to_double(stop), |
| | num, |
| | dtype.value_or(mx::float32), |
| | s); |
| | }, |
| | "start"_a, |
| | "stop"_a, |
| | "num"_a = 50, |
| | "dtype"_a.none() = mx::float32, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Generate ``num`` evenly spaced numbers over interval ``[start, stop]``. |
| | |
| | Args: |
| | start (scalar): Starting value. |
| | stop (scalar): Stopping value. |
| | num (int, optional): Number of samples, defaults to ``50``. |
| | dtype (Dtype, optional): Specifies the data type of the output, |
| | default to ``float32``. |
| | |
| | Returns: |
| | array: The range of values. |
| | )pbdoc"); |
| | m.def( |
| | "kron", |
| | &mx::kron, |
| | nb::arg("a"), |
| | nb::arg("b"), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def kron(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Compute the Kronecker product of two arrays ``a`` and ``b``. |
| | |
| | Args: |
| | a (array): The first input array. |
| | b (array): The second input array. |
| | stream (Union[None, Stream, Device], optional): Optional stream or |
| | device for execution. Default: ``None``. |
| | |
| | Returns: |
| | array: The Kronecker product of ``a`` and ``b``. |
| | |
| | Examples: |
| | >>> a = mx.array([[1, 2], [3, 4]]) |
| | >>> b = mx.array([[0, 5], [6, 7]]) |
| | >>> result = mx.kron(a, b) |
| | >>> print(result) |
| | array([[0, 5, 0, 10], |
| | [6, 7, 12, 14], |
| | [0, 15, 0, 20], |
| | [18, 21, 24, 28]], dtype=int32) |
| | )pbdoc"); |
| | m.def( |
| | "take", |
| | [](const mx::array& a, |
| | const std::variant<nb::int_, mx::array>& indices, |
| | const std::optional<int>& axis, |
| | mx::StreamOrDevice s) { |
| | if (auto pv = std::get_if<nb::int_>(&indices); pv) { |
| | auto idx = nb::cast<int>(*pv); |
| | return axis ? mx::take(a, idx, axis.value(), s) : mx::take(a, idx, s); |
| | } else { |
| | auto indices_ = std::get<mx::array>(indices); |
| | return axis ? mx::take(a, indices_, axis.value(), s) |
| | : mx::take(a, indices_, s); |
| | } |
| | }, |
| | nb::arg(), |
| | "indices"_a, |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def take(a: array, /, indices: Union[int, array], axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Take elements along an axis. |
| | |
| | The elements are taken from ``indices`` along the specified axis. |
| | If the axis is not specified the array is treated as a flattened |
| | 1-D array prior to performing the take. |
| | |
| | As an example, if the ``axis=1`` this is equivalent to ``a[:, indices, ...]``. |
| | |
| | Args: |
| | a (array): Input array. |
| | indices (int or array): Integer index or input array with integral type. |
| | axis (int, optional): Axis along which to perform the take. If unspecified |
| | the array is treated as a flattened 1-D vector. |
| | |
| | Returns: |
| | array: The indexed values of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "take_along_axis", |
| | [](const mx::array& a, |
| | const mx::array& indices, |
| | const std::optional<int>& axis, |
| | mx::StreamOrDevice s) { |
| | if (axis.has_value()) { |
| | return mx::take_along_axis(a, indices, axis.value(), s); |
| | } else { |
| | return mx::take_along_axis(mx::reshape(a, {-1}, s), indices, 0, s); |
| | } |
| | }, |
| | nb::arg(), |
| | "indices"_a, |
| | "axis"_a.none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def take_along_axis(a: array, /, indices: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Take values along an axis at the specified indices. |
| | |
| | Args: |
| | a (array): Input array. |
| | indices (array): Indices array. These should be broadcastable with |
| | the input array excluding the `axis` dimension. |
| | axis (int or None): Axis in the input to take the values from. If |
| | ``axis == None`` the array is flattened to 1D prior to the indexing |
| | operation. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "put_along_axis", |
| | [](const mx::array& a, |
| | const mx::array& indices, |
| | const mx::array& values, |
| | const std::optional<int>& axis, |
| | mx::StreamOrDevice s) { |
| | if (axis.has_value()) { |
| | return mx::put_along_axis(a, indices, values, axis.value(), s); |
| | } else { |
| | return mx::reshape( |
| | mx::put_along_axis( |
| | mx::reshape(a, {-1}, s), indices, values, 0, s), |
| | a.shape(), |
| | s); |
| | } |
| | }, |
| | nb::arg(), |
| | "indices"_a, |
| | "values"_a, |
| | "axis"_a.none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def put_along_axis(a: array, /, indices: array, values: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Put values along an axis at the specified indices. |
| | |
| | Args: |
| | a (array): Destination array. |
| | indices (array): Indices array. These should be broadcastable with |
| | the input array excluding the `axis` dimension. |
| | values (array): Values array. These should be broadcastable with |
| | the indices. |
| | |
| | axis (int or None): Axis in the destination to put the values to. If |
| | ``axis == None`` the destination is flattened prior to the put |
| | operation. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "full", |
| | [](const std::variant<int, mx::Shape>& shape, |
| | const ScalarOrArray& vals, |
| | std::optional<mx::Dtype> dtype, |
| | mx::StreamOrDevice s) { |
| | if (auto pv = std::get_if<int>(&shape); pv) { |
| | return mx::full({*pv}, to_array(vals, dtype), s); |
| | } else { |
| | return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s); |
| | } |
| | }, |
| | "shape"_a, |
| | "vals"_a, |
| | "dtype"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def full(shape: Union[int, Sequence[int]], vals: Union[scalar, array], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Construct an array with the given value. |
| | |
| | Constructs an array of size ``shape`` filled with ``vals``. If ``vals`` |
| | is an :obj:`array` it must be broadcastable to the given ``shape``. |
| | |
| | Args: |
| | shape (int or list(int)): The shape of the output array. |
| | vals (float or int or array): Values to fill the array with. |
| | dtype (Dtype, optional): Data type of the output array. If |
| | unspecified the output type is inferred from ``vals``. |
| | |
| | Returns: |
| | array: The output array with the specified shape and values. |
| | )pbdoc"); |
| | m.def( |
| | "zeros", |
| | [](const std::variant<int, mx::Shape>& shape, |
| | std::optional<mx::Dtype> dtype, |
| | mx::StreamOrDevice s) { |
| | auto t = dtype.value_or(mx::float32); |
| | if (auto pv = std::get_if<int>(&shape); pv) { |
| | return mx::zeros({*pv}, t, s); |
| | } else { |
| | return mx::zeros(std::get<mx::Shape>(shape), t, s); |
| | } |
| | }, |
| | "shape"_a, |
| | "dtype"_a.none() = mx::float32, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def zeros(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Construct an array of zeros. |
| | |
| | Args: |
| | shape (int or list(int)): The shape of the output array. |
| | dtype (Dtype, optional): Data type of the output array. If |
| | unspecified the output type defaults to ``float32``. |
| | |
| | Returns: |
| | array: The array of zeros with the specified shape. |
| | )pbdoc"); |
| | m.def( |
| | "zeros_like", |
| | &mx::zeros_like, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def zeros_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | An array of zeros like the input. |
| | |
| | Args: |
| | a (array): The input to take the shape and type from. |
| | |
| | Returns: |
| | array: The output array filled with zeros. |
| | )pbdoc"); |
| | m.def( |
| | "ones", |
| | [](const std::variant<int, mx::Shape>& shape, |
| | std::optional<mx::Dtype> dtype, |
| | mx::StreamOrDevice s) { |
| | auto t = dtype.value_or(mx::float32); |
| | if (auto pv = std::get_if<int>(&shape); pv) { |
| | return mx::ones({*pv}, t, s); |
| | } else { |
| | return mx::ones(std::get<mx::Shape>(shape), t, s); |
| | } |
| | }, |
| | "shape"_a, |
| | "dtype"_a.none() = mx::float32, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def ones(shape: Union[int, Sequence[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Construct an array of ones. |
| | |
| | Args: |
| | shape (int or list(int)): The shape of the output array. |
| | dtype (Dtype, optional): Data type of the output array. If |
| | unspecified the output type defaults to ``float32``. |
| | |
| | Returns: |
| | array: The array of ones with the specified shape. |
| | )pbdoc"); |
| | m.def( |
| | "ones_like", |
| | &mx::ones_like, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def ones_like(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | An array of ones like the input. |
| | |
| | Args: |
| | a (array): The input to take the shape and type from. |
| | |
| | Returns: |
| | array: The output array filled with ones. |
| | )pbdoc"); |
| | m.def( |
| | "eye", |
| | [](int n, |
| | std::optional<int> m, |
| | int k, |
| | std::optional<mx::Dtype> dtype, |
| | mx::StreamOrDevice s) { |
| | return mx::eye(n, m.value_or(n), k, dtype.value_or(mx::float32), s); |
| | }, |
| | "n"_a, |
| | "m"_a = nb::none(), |
| | "k"_a = 0, |
| | "dtype"_a.none() = mx::float32, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Create an identity matrix or a general diagonal matrix. |
| | |
| | Args: |
| | n (int): The number of rows in the output. |
| | m (int, optional): The number of columns in the output. Defaults to n. |
| | k (int, optional): Index of the diagonal. Defaults to 0 (main diagonal). |
| | dtype (Dtype, optional): Data type of the output array. Defaults to float32. |
| | stream (Stream, optional): Stream or device. Defaults to None. |
| | |
| | Returns: |
| | array: An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one. |
| | )pbdoc"); |
| | m.def( |
| | "identity", |
| | [](int n, std::optional<mx::Dtype> dtype, mx::StreamOrDevice s) { |
| | return mx::identity(n, dtype.value_or(mx::float32), s); |
| | }, |
| | "n"_a, |
| | "dtype"_a.none() = mx::float32, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Create a square identity matrix. |
| | |
| | Args: |
| | n (int): The number of rows and columns in the output. |
| | dtype (Dtype, optional): Data type of the output array. Defaults to float32. |
| | stream (Stream, optional): Stream or device. Defaults to None. |
| | |
| | Returns: |
| | array: An identity matrix of size n x n. |
| | )pbdoc"); |
| | m.def( |
| | "tri", |
| | [](int n, |
| | std::optional<int> m, |
| | int k, |
| | std::optional<mx::Dtype> type, |
| | mx::StreamOrDevice s) { |
| | return mx::tri(n, m.value_or(n), k, type.value_or(mx::float32), s); |
| | }, |
| | "n"_a, |
| | "m"_a = nb::none(), |
| | "k"_a = 0, |
| | "dtype"_a.none() = mx::float32, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | An array with ones at and below the given diagonal and zeros elsewhere. |
| | |
| | Args: |
| | n (int): The number of rows in the output. |
| | m (int, optional): The number of cols in the output. Defaults to ``None``. |
| | k (int, optional): The diagonal of the 2-D array. Defaults to ``0``. |
| | dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``. |
| | stream (Stream, optional): Stream or device. Defaults to ``None``. |
| | |
| | Returns: |
| | array: Array with its lower triangle filled with ones and zeros elsewhere |
| | )pbdoc"); |
| | m.def( |
| | "tril", |
| | &mx::tril, |
| | "x"_a, |
| | "k"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Zeros the array above the given diagonal. |
| | |
| | Args: |
| | x (array): input array. |
| | k (int, optional): The diagonal of the 2-D array. Defaults to ``0``. |
| | stream (Stream, optional): Stream or device. Defaults to ``None``. |
| | |
| | Returns: |
| | array: Array zeroed above the given diagonal |
| | )pbdoc"); |
| | m.def( |
| | "triu", |
| | &mx::triu, |
| | "x"_a, |
| | "k"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Zeros the array below the given diagonal. |
| | |
| | Args: |
| | x (array): input array. |
| | k (int, optional): The diagonal of the 2-D array. Defaults to ``0``. |
| | stream (Stream, optional): Stream or device. Defaults to ``None``. |
| | |
| | Returns: |
| | array: Array zeroed below the given diagonal |
| | )pbdoc"); |
| | m.def( |
| | "allclose", |
| | &mx::allclose, |
| | nb::arg(), |
| | nb::arg(), |
| | "rtol"_a = 1e-5, |
| | "atol"_a = 1e-8, |
| | nb::kw_only(), |
| | "equal_nan"_a = false, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Approximate comparison of two arrays. |
| | |
| | Infinite values are considered equal if they have the same sign, NaN values are not equal unless ``equal_nan`` is ``True``. |
| | |
| | The arrays are considered equal if: |
| | |
| | .. code-block:: |
| | |
| | all(abs(a - b) <= (atol + rtol * abs(b))) |
| | |
| | Note unlike :func:`array_equal`, this function supports numpy-style |
| | broadcasting. |
| | |
| | Args: |
| | a (array): Input array. |
| | b (array): Input array. |
| | rtol (float): Relative tolerance. |
| | atol (float): Absolute tolerance. |
| | equal_nan (bool): If ``True``, NaNs are considered equal. |
| | Defaults to ``False``. |
| | |
| | Returns: |
| | array: The boolean output scalar indicating if the arrays are close. |
| | )pbdoc"); |
| | m.def( |
| | "isclose", |
| | &mx::isclose, |
| | nb::arg(), |
| | nb::arg(), |
| | "rtol"_a = 1e-5, |
| | "atol"_a = 1e-8, |
| | nb::kw_only(), |
| | "equal_nan"_a = false, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Returns a boolean array where two arrays are element-wise equal within a tolerance. |
| | |
| | Infinite values are considered equal if they have the same sign, NaN values are |
| | not equal unless ``equal_nan`` is ``True``. |
| | |
| | Two values are considered equal if: |
| | |
| | .. code-block:: |
| | |
| | abs(a - b) <= (atol + rtol * abs(b)) |
| | |
| | Note unlike :func:`array_equal`, this function supports numpy-style |
| | broadcasting. |
| | |
| | Args: |
| | a (array): Input array. |
| | b (array): Input array. |
| | rtol (float): Relative tolerance. |
| | atol (float): Absolute tolerance. |
| | equal_nan (bool): If ``True``, NaNs are considered equal. |
| | Defaults to ``False``. |
| | |
| | Returns: |
| | array: The boolean output scalar indicating if the arrays are close. |
| | )pbdoc"); |
| | m.def( |
| | "all", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | nb::arg(), |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def all(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | An `and` reduction over the given axes. |
| | |
| | Args: |
| | a (array): Input array. |
| | axis (int or list(int), optional): Optional axis or |
| | axes to reduce over. If unspecified this defaults |
| | to reducing over the entire array. |
| | keepdims (bool, optional): Keep reduced axes as |
| | singleton dimensions, defaults to `False`. |
| | |
| | Returns: |
| | array: The output array with the corresponding axes reduced. |
| | )pbdoc"); |
| | m.def( |
| | "any", |
| | [](const mx::array& a, |
| | const IntOrVec& axis, |
| | bool keepdims, |
| | mx::StreamOrDevice s) { |
| | return mx::any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| | }, |
| | nb::arg(), |
| | "axis"_a = nb::none(), |
| | "keepdims"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def any(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | An `or` reduction over the given axes. |
| | |
| | Args: |
| | a (array): Input array. |
| | axis (int or list(int), optional): Optional axis or |
| | axes to reduce over. If unspecified this defaults |
| | to reducing over the entire array. |
| | keepdims (bool, optional): Keep reduced axes as |
| | singleton dimensions, defaults to `False`. |
| | |
| | Returns: |
| | array: The output array with the corresponding axes reduced. |
| | )pbdoc"); |
| | m.def( |
| | "minimum", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::minimum(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def minimum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise minimum. |
| | |
| | Take the element-wise min of two arrays with numpy-style broadcasting |
| | semantics. Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The min of ``a`` and ``b``. |
| | )pbdoc"); |
| | m.def( |
| | "maximum", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::maximum(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def maximum(a: Union[scalar, array], b: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise maximum. |
| | |
| | Take the element-wise max of two arrays with numpy-style broadcasting |
| | semantics. Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The max of ``a`` and ``b``. |
| | )pbdoc"); |
| | m.def( |
| | "floor", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::floor(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def floor(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise floor. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The floor of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "ceil", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::ceil(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def ceil(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise ceil. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The ceil of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "isnan", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::isnan(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def isnan(a: array, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return a boolean array indicating which elements are NaN. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The boolean array indicating which elements are NaN. |
| | )pbdoc"); |
| | m.def( |
| | "isinf", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::isinf(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def isinf(a: array, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return a boolean array indicating which elements are +/- inifnity. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The boolean array indicating which elements are +/- infinity. |
| | )pbdoc"); |
| |   // Element-wise finiteness test. `a` is positional-only and `stream` |
| |   // keyword-only, matching the `/` and `*` in the signature string. |
| |   m.def( |
| |       "isfinite", |
| |       [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| |         return mx::isfinite(to_array(a), s); |
| |       }, |
| |       nb::arg(), |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def isfinite(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Return a boolean array indicating which elements are finite. |
| | |
| |         An element is finite if it is not infinite or NaN. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             stream (Union[None, Stream, Device]): Optional stream or device. |
| | |
| |         Returns: |
| |             array: The boolean array indicating which elements are finite. |
| |       )pbdoc"); |
| |   // Element-wise test for positive infinity. `a` is positional-only and |
| |   // `stream` keyword-only, matching the `/` and `*` in the signature string. |
| |   m.def( |
| |       "isposinf", |
| |       [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| |         return mx::isposinf(to_array(a), s); |
| |       }, |
| |       nb::arg(), |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def isposinf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Return a boolean array indicating which elements are positive infinity. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             stream (Union[None, Stream, Device]): Optional stream or device. |
| | |
| |         Returns: |
| |             array: The boolean array indicating which elements are positive infinity. |
| |       )pbdoc"); |
| |   // Element-wise test for negative infinity. `a` is positional-only and |
| |   // `stream` keyword-only, matching the `/` and `*` in the signature string. |
| |   m.def( |
| |       "isneginf", |
| |       [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| |         return mx::isneginf(to_array(a), s); |
| |       }, |
| |       nb::arg(), |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def isneginf(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Return a boolean array indicating which elements are negative infinity. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             stream (Union[None, Stream, Device]): Optional stream or device. |
| | |
| |         Returns: |
| |             array: The boolean array indicating which elements are negative infinity. |
| |       )pbdoc"); |
| |   // Move axis `source` of `a` to position `destination`. mx::moveaxis is |
| |   // bound directly; `a` is positional-only (unnamed nb::arg) and `stream` |
| |   // is keyword-only. |
| |   m.def( |
| |       "moveaxis", |
| |       &mx::moveaxis, |
| |       nb::arg(), |
| |       "source"_a, |
| |       "destination"_a, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Move an axis to a new position. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             source (int): Specifies the source axis. |
| |             destination (int): Specifies the destination axis. |
| | |
| |         Returns: |
| |             array: The array with the axis moved. |
| |       )pbdoc"); |
| |   // Exchange axes `axis1` and `axis2` of `a`. mx::swapaxes is bound |
| |   // directly; `a` is positional-only and `stream` is keyword-only. |
| |   m.def( |
| |       "swapaxes", |
| |       &mx::swapaxes, |
| |       nb::arg(), |
| |       "axis1"_a, |
| |       "axis2"_a, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Swap two axes of an array. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis1 (int): Specifies the first axis. |
| |             axis2 (int): Specifies the second axis. |
| | |
| |         Returns: |
| |             array: The array with swapped axes. |
| |       )pbdoc"); |
| |   // Permute the axes of `a`. When `axes` is None the mx::transpose overload |
| |   // without an axes list is used (documented below as reversing the axes). |
| |   m.def( |
| |       "transpose", |
| |       [](const mx::array& a, |
| |          const std::optional<std::vector<int>>& axes, |
| |          mx::StreamOrDevice s) { |
| |         return axes ? mx::transpose(a, *axes, s) : mx::transpose(a, s); |
| |       }, |
| |       nb::arg(), |
| |       "axes"_a = nb::none(), |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def transpose(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Transpose the dimensions of the array. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axes (list(int), optional): Specifies the source axis for each axis |
| |               in the new array. The default is to reverse the axes. |
| | |
| |         Returns: |
| |             array: The transposed array. |
| |       )pbdoc"); |
| |   // Alias of `transpose`: same arguments, dispatching to mx::transpose. |
| |   m.def( |
| |       "permute_dims", |
| |       [](const mx::array& a, |
| |          const std::optional<std::vector<int>>& axes, |
| |          mx::StreamOrDevice s) { |
| |         if (!axes.has_value()) { |
| |           return mx::transpose(a, s); |
| |         } |
| |         return mx::transpose(a, axes.value(), s); |
| |       }, |
| |       nb::arg(), |
| |       "axes"_a = nb::none(), |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         See :func:`transpose`. |
| |       )pbdoc"); |
| |   // Sum reduction; get_reduce_axes resolves `axis` against a.ndim(). |
| |   // NOTE(review): unlike the sibling reductions, the input is registered as |
| |   // "array"_a rather than nb::arg(), so it is also accepted by keyword even |
| |   // though the signature string marks it positional-only with `/`. |
| |   m.def( |
| |       "sum", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool keepdims, |
| |          mx::StreamOrDevice s) { |
| |         const auto reduce_axes = get_reduce_axes(axis, a.ndim()); |
| |         return mx::sum(a, reduce_axes, keepdims, s); |
| |       }, |
| |       "array"_a, |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def sum(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Sum reduce the array over the given axes. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or |
| |               axes to reduce over. If unspecified this defaults |
| |               to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| | |
| |         Returns: |
| |             array: The output array with the corresponding axes reduced. |
| |       )pbdoc"); |
| |   // Product reduction over `axis`; get_reduce_axes resolves `axis` against |
| |   // a.ndim() (per the docstring, the whole array when unspecified). |
| |   m.def( |
| |       "prod", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool keepdims, |
| |          mx::StreamOrDevice s) { |
| |         return mx::prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def prod(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         A product reduction over the given axes. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or |
| |               axes to reduce over. If unspecified this defaults |
| |               to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| | |
| |         Returns: |
| |             array: The output array with the corresponding axes reduced. |
| |       )pbdoc"); |
| |   // `min` reduction; the axes come from get_reduce_axes(axis, a.ndim()). |
| |   m.def( |
| |       "min", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool keepdims, |
| |          mx::StreamOrDevice s) { |
| |         auto axes = get_reduce_axes(axis, a.ndim()); |
| |         return mx::min(a, axes, keepdims, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def min(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         A `min` reduction over the given axes. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or |
| |               axes to reduce over. If unspecified this defaults |
| |               to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| | |
| |         Returns: |
| |             array: The output array with the corresponding axes reduced. |
| |       )pbdoc"); |
| |   // `max` reduction; the axes come from get_reduce_axes(axis, a.ndim()). |
| |   m.def( |
| |       "max", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool keepdims, |
| |          mx::StreamOrDevice s) { |
| |         const auto dims = get_reduce_axes(axis, a.ndim()); |
| |         return mx::max(a, dims, keepdims, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def max(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         A `max` reduction over the given axes. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or |
| |               axes to reduce over. If unspecified this defaults |
| |               to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| | |
| |         Returns: |
| |             array: The output array with the corresponding axes reduced. |
| |       )pbdoc"); |
| |   // Cumulative logsumexp. Without an axis the input is reshaped to 1-D |
| |   // and scanned along axis 0. |
| |   m.def( |
| |       "logcumsumexp", |
| |       [](const mx::array& a, |
| |          std::optional<int> axis, |
| |          bool reverse, |
| |          bool inclusive, |
| |          mx::StreamOrDevice s) { |
| |         if (!axis) { |
| |           const mx::array flat = mx::reshape(a, {-1}, s); |
| |           return mx::logcumsumexp(flat, 0, reverse, inclusive, s); |
| |         } |
| |         return mx::logcumsumexp(a, *axis, reverse, inclusive, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       nb::kw_only(), |
| |       "reverse"_a = false, |
| |       "inclusive"_a = true, |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def logcumsumexp(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Return the cumulative logsumexp of the elements along the given axis. |
| | |
| |         Args: |
| |           a (array): Input array |
| |           axis (int, optional): Optional axis to compute the cumulative logsumexp |
| |             over. If unspecified the cumulative logsumexp of the flattened array is |
| |             returned. |
| |           reverse (bool): Perform the cumulative logsumexp in reverse. |
| |           inclusive (bool): The i-th element of the output includes the i-th |
| |             element of the input. |
| | |
| |         Returns: |
| |           array: The output array. |
| |       )pbdoc"); |
| |   // log-sum-exp reduction; the axes come from get_reduce_axes(axis, a.ndim()). |
| |   m.def( |
| |       "logsumexp", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool keepdims, |
| |          mx::StreamOrDevice s) { |
| |         const auto reduced = get_reduce_axes(axis, a.ndim()); |
| |         return mx::logsumexp(a, reduced, keepdims, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def logsumexp(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         A `log-sum-exp` reduction over the given axes. |
| | |
| |         The log-sum-exp reduction is a numerically stable version of: |
| | |
| |         .. code-block:: |
| | |
| |           log(sum(exp(a), axis)) |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or |
| |               axes to reduce over. If unspecified this defaults |
| |               to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| | |
| |         Returns: |
| |             array: The output array with the corresponding axes reduced. |
| |       )pbdoc"); |
| |   // Mean reduction; the axes come from get_reduce_axes(axis, a.ndim()). |
| |   m.def( |
| |       "mean", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool keepdims, |
| |          mx::StreamOrDevice s) { |
| |         auto over = get_reduce_axes(axis, a.ndim()); |
| |         return mx::mean(a, over, keepdims, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def mean(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Compute the mean(s) over the given axes. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or |
| |               axes to reduce over. If unspecified this defaults |
| |               to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| | |
| |         Returns: |
| |             array: The output array of means. |
| |       )pbdoc"); |
| |   // Variance reduction with divisor N - ddof; the axes come from |
| |   // get_reduce_axes(axis, a.ndim()). |
| |   m.def( |
| |       "var", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool keepdims, |
| |          int ddof, |
| |          mx::StreamOrDevice s) { |
| |         const auto reduce_over = get_reduce_axes(axis, a.ndim()); |
| |         return mx::var(a, reduce_over, keepdims, ddof, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       "ddof"_a = 0, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def var(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Compute the variance(s) over the given axes. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or |
| |               axes to reduce over. If unspecified this defaults |
| |               to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| |             ddof (int, optional): The divisor to compute the variance |
| |               is ``N - ddof``, defaults to 0. |
| | |
| |         Returns: |
| |             array: The output array of variances. |
| |       )pbdoc"); |
| |   // Standard-deviation reduction with divisor N - ddof; the axes come from |
| |   // get_reduce_axes(axis, a.ndim()). |
| |   m.def( |
| |       "std", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool keepdims, |
| |          int ddof, |
| |          mx::StreamOrDevice s) { |
| |         auto target_axes = get_reduce_axes(axis, a.ndim()); |
| |         return mx::std(a, target_axes, keepdims, ddof, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       "ddof"_a = 0, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def std(a: array, /, axis: Union[None, int, Sequence[int]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Compute the standard deviation(s) over the given axes. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or |
| |               axes to reduce over. If unspecified this defaults |
| |               to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| |             ddof (int, optional): The divisor to compute the variance |
| |               is ``N - ddof``, defaults to 0. |
| | |
| |         Returns: |
| |             array: The output array of standard deviations. |
| |       )pbdoc"); |
| |   // Split `a` along `axis`, either into N equal sections (int) or at the |
| |   // given start indices (list). The result is several arrays, so the |
| |   // signature advertises list[array], consistent with the docstring. |
| |   m.def( |
| |       "split", |
| |       [](const mx::array& a, |
| |          const std::variant<int, mx::Shape>& indices_or_sections, |
| |          int axis, |
| |          mx::StreamOrDevice s) { |
| |         if (auto pv = std::get_if<int>(&indices_or_sections); pv) { |
| |           return mx::split(a, *pv, axis, s); |
| |         } else { |
| |           return mx::split( |
| |               a, std::get<mx::Shape>(indices_or_sections), axis, s); |
| |         } |
| |       }, |
| |       nb::arg(), |
| |       "indices_or_sections"_a, |
| |       "axis"_a = 0, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def split(a: array, /, indices_or_sections: Union[int, Sequence[int]], axis: int = 0, *, stream: Union[None, Stream, Device] = None) -> list[array]"), |
| |       R"pbdoc( |
| |         Split an array along a given axis. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             indices_or_sections (int or list(int)): If ``indices_or_sections`` |
| |               is an integer the array is split into that many sections of equal |
| |               size. An error is raised if this is not possible. If ``indices_or_sections`` |
| |               is a list, the list contains the indices of the start of each subarray |
| |               along the given axis. |
| |             axis (int, optional): Axis to split along, defaults to `0`. |
| | |
| |         Returns: |
| |             list(array): A list of split arrays. |
| |       )pbdoc"); |
| |   // Indices of the minimum; with axis=None the mx::argmin overload without |
| |   // an axis is used (the docstring says the whole array is reduced). |
| |   m.def( |
| |       "argmin", |
| |       [](const mx::array& a, |
| |          std::optional<int> axis, |
| |          bool keepdims, |
| |          mx::StreamOrDevice s) { |
| |         return axis ? mx::argmin(a, *axis, keepdims, s) |
| |                     : mx::argmin(a, keepdims, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def argmin(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Indices of the minimum values along the axis. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int, optional): Optional axis to reduce over. If unspecified |
| |               this defaults to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| | |
| |         Returns: |
| |             array: The ``uint32`` array with the indices of the minimum values. |
| |       )pbdoc"); |
| |   // Indices of the maximum; with axis=None the mx::argmax overload without |
| |   // an axis is used (the docstring says the whole array is reduced). |
| |   m.def( |
| |       "argmax", |
| |       [](const mx::array& a, |
| |          std::optional<int> axis, |
| |          bool keepdims, |
| |          mx::StreamOrDevice s) { |
| |         if (!axis.has_value()) { |
| |           return mx::argmax(a, keepdims, s); |
| |         } |
| |         return mx::argmax(a, axis.value(), keepdims, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       "keepdims"_a = false, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def argmax(a: array, /, axis: Union[None, int] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Indices of the maximum values along the axis. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int, optional): Optional axis to reduce over. If unspecified |
| |               this defaults to reducing over the entire array. |
| |             keepdims (bool, optional): Keep reduced axes as |
| |               singleton dimensions, defaults to `False`. |
| | |
| |         Returns: |
| |             array: The ``uint32`` array with the indices of the maximum values. |
| |       )pbdoc"); |
| |   // Sorted copy of `a`. axis=None selects the mx::sort overload without an |
| |   // axis (the docstring describes this as sorting the flattened array). |
| |   m.def( |
| |       "sort", |
| |       [](const mx::array& a, std::optional<int> axis, mx::StreamOrDevice s) { |
| |         return axis ? mx::sort(a, *axis, s) : mx::sort(a, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a.none() = -1, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def sort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Returns a sorted copy of the array. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or None, optional): Optional axis to sort over. |
| |               If ``None``, this sorts over the flattened array. |
| |               If unspecified, it defaults to -1 (sorting over the last axis). |
| | |
| |         Returns: |
| |             array: The sorted array. |
| |       )pbdoc"); |
| |   // Sorting permutation of `a`. axis=None selects the mx::argsort overload |
| |   // without an axis (documented as sorting the flattened array). |
| |   m.def( |
| |       "argsort", |
| |       [](const mx::array& a, std::optional<int> axis, mx::StreamOrDevice s) { |
| |         if (!axis) { |
| |           return mx::argsort(a, s); |
| |         } |
| |         return mx::argsort(a, axis.value(), s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a.none() = -1, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def argsort(a: array, /, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Returns the indices that sort the array. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or None, optional): Optional axis to sort over. |
| |               If ``None``, this sorts over the flattened array. |
| |               If unspecified, it defaults to -1 (sorting over the last axis). |
| | |
| |         Returns: |
| |             array: The ``uint32`` array containing indices that sort the input. |
| |       )pbdoc"); |
| |   // Partitioned copy around the `kth` element. axis=None selects the |
| |   // mx::partition overload without an axis (documented as the flattened array). |
| |   m.def( |
| |       "partition", |
| |       [](const mx::array& a, |
| |          int kth, |
| |          std::optional<int> axis, |
| |          mx::StreamOrDevice s) { |
| |         return axis ? mx::partition(a, kth, *axis, s) |
| |                     : mx::partition(a, kth, s); |
| |       }, |
| |       nb::arg(), |
| |       "kth"_a, |
| |       "axis"_a.none() = -1, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def partition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Returns a partitioned copy of the array such that the smaller ``kth`` |
| |         elements are first. |
| | |
| |         The ordering of the elements in partitions is undefined. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             kth (int): Element at the ``kth`` index will be in its sorted |
| |               position in the output. All elements before the kth index will |
| |               be less or equal to the ``kth`` element and all elements after |
| |               will be greater or equal to the ``kth`` element in the output. |
| |             axis (int or None, optional): Optional axis to partition over. |
| |               If ``None``, this partitions over the flattened array. |
| |               If unspecified, it defaults to ``-1``. |
| | |
| |         Returns: |
| |             array: The partitioned array. |
| |       )pbdoc"); |
| |   // Indices that partition `a` around the `kth` element. axis=None selects |
| |   // the mx::argpartition overload without an axis (the flattened array). |
| |   m.def( |
| |       "argpartition", |
| |       [](const mx::array& a, |
| |          int kth, |
| |          std::optional<int> axis, |
| |          mx::StreamOrDevice s) { |
| |         if (axis) { |
| |           return mx::argpartition(a, kth, *axis, s); |
| |         } else { |
| |           return mx::argpartition(a, kth, s); |
| |         } |
| |       }, |
| |       nb::arg(), |
| |       "kth"_a, |
| |       "axis"_a.none() = -1, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def argpartition(a: array, /, kth: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Returns the indices that partition the array. |
| | |
| |         The ordering of the elements within a partition, as given by the |
| |         indices, is undefined. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             kth (int): Element index at the ``kth`` position in the output will |
| |               give the sorted position. All indices before the ``kth`` position |
| |               will be of elements less or equal to the element at the ``kth`` |
| |               index and all indices after will be of elements greater or equal |
| |               to the element at the ``kth`` index. |
| |             axis (int or None, optional): Optional axis to partition over. |
| |               If ``None``, this partitions over the flattened array. |
| |               If unspecified, it defaults to ``-1``. |
| | |
| |         Returns: |
| |             array: The ``uint32`` array containing indices that partition the input. |
| |       )pbdoc"); |
| |   // The `k` largest elements of `a`. axis=None selects the mx::topk overload |
| |   // without an axis (documented as selecting over the flattened array). |
| |   m.def( |
| |       "topk", |
| |       [](const mx::array& a, |
| |          int k, |
| |          std::optional<int> axis, |
| |          mx::StreamOrDevice s) { |
| |         if (!axis.has_value()) { |
| |           return mx::topk(a, k, s); |
| |         } |
| |         return mx::topk(a, k, *axis, s); |
| |       }, |
| |       nb::arg(), |
| |       "k"_a, |
| |       "axis"_a.none() = -1, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def topk(a: array, /, k: int, axis: Union[None, int] = -1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Returns the ``k`` largest elements from the input along a given axis. |
| | |
| |         The elements will not necessarily be in sorted order. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             k (int): ``k`` top elements to be returned |
| |             axis (int or None, optional): Optional axis to select over. |
| |               If ``None``, this selects the top ``k`` elements over the |
| |               flattened array. If unspecified, it defaults to ``-1``. |
| | |
| |         Returns: |
| |             array: The top ``k`` elements from the input. |
| |       )pbdoc"); |
| |   // Broadcast a scalar or array to `shape`; scalars are converted with |
| |   // to_array first, which is why the docstring lists both input kinds. |
| |   m.def( |
| |       "broadcast_to", |
| |       [](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) { |
| |         return mx::broadcast_to(to_array(a), shape, s); |
| |       }, |
| |       nb::arg(), |
| |       "shape"_a, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def broadcast_to(a: Union[scalar, array], /, shape: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Broadcast an array to the given shape. |
| | |
| |         The broadcasting semantics are the same as Numpy. |
| | |
| |         Args: |
| |             a (scalar or array): Input scalar or array. |
| |             shape (list(int)): The shape to broadcast to. |
| | |
| |         Returns: |
| |             array: The output array with the new shape. |
| |       )pbdoc"); |
| |   // Broadcast all positional arrays against one another. The *args tuple is |
| |   // converted to a vector of arrays first. NOTE(review): the signature says |
| |   // Tuple[array, ...]; confirm how the returned container is converted. |
| |   m.def( |
| |       "broadcast_arrays", |
| |       [](const nb::args& args, mx::StreamOrDevice s) { |
| |         auto inputs = nb::cast<std::vector<mx::array>>(args); |
| |         return broadcast_arrays(inputs, s); |
| |       }, |
| |       nb::arg(), |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def broadcast_arrays(*arrays: array, stream: Union[None, Stream, Device] = None) -> Tuple[array, ...]"), |
| |       R"pbdoc( |
| |         Broadcast arrays against one another. |
| | |
| |         The broadcasting semantics are the same as Numpy. |
| | |
| |         Args: |
| |             *arrays (array): The input arrays. |
| | |
| |         Returns: |
| |             tuple(array): The output arrays with the broadcasted shape. |
| |       )pbdoc"); |
| |   // Softmax over `axis` (axes resolved by get_reduce_axes). The keyword-only |
| |   // `precise` flag is forwarded unchanged to mx::softmax; it is listed in the |
| |   // signature string so stubs and help() match the accepted arguments. |
| |   m.def( |
| |       "softmax", |
| |       [](const mx::array& a, |
| |          const IntOrVec& axis, |
| |          bool precise, |
| |          mx::StreamOrDevice s) { |
| |         return mx::softmax(a, get_reduce_axes(axis, a.ndim()), precise, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a = nb::none(), |
| |       nb::kw_only(), |
| |       "precise"_a = false, |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def softmax(a: array, /, axis: Union[None, int, Sequence[int]] = None, *, precise: bool = False, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Perform the softmax along the given axis. |
| | |
| |         This operation is a numerically stable version of: |
| | |
| |         .. code-block:: |
| | |
| |           exp(a) / sum(exp(a), axis, keepdims=True) |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             axis (int or list(int), optional): Optional axis or axes to compute |
| |              the softmax over. If unspecified this performs the softmax over |
| |              the full array. |
| | |
| |         Returns: |
| |             array: The output of the softmax. |
| |       )pbdoc"); |
| |   // Join arrays along an existing axis. axis=None selects the |
| |   // mx::concatenate overload that takes no axis. |
| |   m.def( |
| |       "concatenate", |
| |       [](const std::vector<mx::array>& arrays, |
| |          std::optional<int> axis, |
| |          mx::StreamOrDevice s) { |
| |         return axis ? mx::concatenate(arrays, *axis, s) |
| |                     : mx::concatenate(arrays, s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a.none() = 0, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def concatenate(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Concatenate the arrays along the given axis. |
| | |
| |         Args: |
| |             arrays (list(array)): Input :obj:`list` or :obj:`tuple` of arrays. |
| |             axis (int, optional): Optional axis to concatenate along. If |
| |               unspecified defaults to ``0``. |
| | |
| |         Returns: |
| |             array: The concatenated array. |
| |       )pbdoc"); |
| |   // Alias of `concatenate` with an identical binding. |
| |   m.def( |
| |       "concat", |
| |       [](const std::vector<mx::array>& arrays, |
| |          std::optional<int> axis, |
| |          mx::StreamOrDevice s) { |
| |         if (!axis) { |
| |           return mx::concatenate(arrays, s); |
| |         } |
| |         return mx::concatenate(arrays, axis.value(), s); |
| |       }, |
| |       nb::arg(), |
| |       "axis"_a.none() = 0, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def concat(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         See :func:`concatenate`. |
| |       )pbdoc"); |
| |   // Stack arrays along a new axis. axis=None selects the mx::stack overload |
| |   // without an axis argument. |
| |   m.def( |
| |       "stack", |
| |       [](const std::vector<mx::array>& arrays, |
| |          std::optional<int> axis, |
| |          mx::StreamOrDevice s) { |
| |         if (axis.has_value()) { |
| |           return mx::stack(arrays, axis.value(), s); |
| |         } else { |
| |           return mx::stack(arrays, s); |
| |         } |
| |       }, |
| |       nb::arg(), |
| |       // .none() ensures an explicit `axis=None` is accepted (as for |
| |       // concatenate); the signature already advertises Optional[int]. |
| |       "axis"_a.none() = 0, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def stack(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Stacks the arrays along a new axis. |
| | |
| |         Args: |
| |             arrays (list(array)): A list of arrays to stack. |
| |             axis (int, optional): The axis in the result array along which the |
| |               input arrays are stacked. Defaults to ``0``. |
| |             stream (Stream, optional): Stream or device. Defaults to ``None``. |
| | |
| |         Returns: |
| |             array: The resulting stacked array. |
| |       )pbdoc"); |
| |   // Coordinate grids from the 1-D arrays passed as *args. mx::meshgrid |
| |   // yields one output per input, so the signature advertises list[array], |
| |   // matching the docstring. |
| |   m.def( |
| |       "meshgrid", |
| |       [](nb::args arrays_, |
| |          bool sparse, |
| |          std::string indexing, |
| |          mx::StreamOrDevice s) { |
| |         std::vector<mx::array> arrays = |
| |             nb::cast<std::vector<mx::array>>(arrays_); |
| |         return mx::meshgrid(arrays, sparse, indexing, s); |
| |       }, |
| |       "arrays"_a, |
| |       "sparse"_a = false, |
| |       "indexing"_a = "xy", |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def meshgrid(*arrays: array, sparse: Optional[bool] = False, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> list[array]"), |
| |       R"pbdoc( |
| |         Generate multidimensional coordinate grids from 1-D coordinate arrays. |
| | |
| |         Args: |
| |             *arrays (array): Input arrays. |
| |             sparse (bool, optional): If ``True``, a sparse grid is returned in which each output |
| |               array keeps the length of its input along one dimension and has size ``1`` |
| |               along all other dimensions, so the outputs broadcast against each other. |
| |               If ``False``, a dense grid is returned. Defaults to ``False``. |
| |             indexing (str, optional): Cartesian ('xy') or matrix ('ij') indexing of the output arrays. |
| |               Defaults to ``'xy'``. |
| | |
| |         Returns: |
| |             list(array): The output arrays. |
| |       )pbdoc"); |
| |   // Repeat each element `repeats` times. axis=None selects the mx::repeat |
| |   // overload without an axis (documented as using the flattened input). |
| |   m.def( |
| |       "repeat", |
| |       [](const mx::array& array, |
| |          int repeats, |
| |          std::optional<int> axis, |
| |          mx::StreamOrDevice s) { |
| |         if (!axis.has_value()) { |
| |           return mx::repeat(array, repeats, s); |
| |         } |
| |         return mx::repeat(array, repeats, *axis, s); |
| |       }, |
| |       nb::arg(), |
| |       "repeats"_a, |
| |       "axis"_a = nb::none(), |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def repeat(array: array, repeats: int, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |     Repeat an array along a specified axis. |
| | |
| |     Args: |
| |         array (array): Input array. |
| |         repeats (int): The number of repetitions for each element. |
| |         axis (int, optional): The axis in which to repeat the array along. If |
| |           unspecified it uses the flattened array of the input and repeats |
| |           along axis 0. |
| |         stream (Stream, optional): Stream or device. Defaults to ``None``. |
| | |
| |     Returns: |
| |         array: The resulting repeated array. |
| |   )pbdoc"); |
| |   // Clip `a` to [a_min, a_max]. Each non-None bound is converted with |
| |   // to_arrays(a, bound) and its `.second` element is used as the limit; a |
| |   // None bound is passed through as std::nullopt. |
| |   m.def( |
| |       "clip", |
| |       [](const mx::array& a, |
| |          const std::optional<ScalarOrArray>& min, |
| |          const std::optional<ScalarOrArray>& max, |
| |          mx::StreamOrDevice s) { |
| |         std::optional<mx::array> min_ = std::nullopt; |
| |         std::optional<mx::array> max_ = std::nullopt; |
| |         if (min) { |
| |           min_ = to_arrays(a, min.value()).second; |
| |         } |
| |         if (max) { |
| |           max_ = to_arrays(a, max.value()).second; |
| |         } |
| |         return mx::clip(a, min_, max_, s); |
| |       }, |
| |       nb::arg(), |
| |       "a_min"_a.none(), |
| |       "a_max"_a.none(), |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Clip the values of the array between the given minimum and maximum. |
| | |
| |         If either ``a_min`` or ``a_max`` is ``None``, then the corresponding edge |
| |         is ignored. At least one of ``a_min`` and ``a_max`` must not be ``None``. |
| |         The input ``a`` and the limits must broadcast with one another. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             a_min (scalar or array or None): Minimum value to clip to. |
| |             a_max (scalar or array or None): Maximum value to clip to. |
| | |
| |         Returns: |
| |             array: The clipped array. |
| |       )pbdoc"); |
| |   // Pad `a`. `pad_width` may be an int, a 1-tuple, a single (before, after) |
| |   // pair, or a list of pairs; a one-element list is treated like a single |
| |   // pair. The pad value comes from the Python `constant_values` argument. |
| |   m.def( |
| |       "pad", |
| |       [](const mx::array& a, |
| |          const std::variant< |
| |              int, |
| |              std::tuple<int>, |
| |              std::pair<int, int>, |
| |              std::vector<std::pair<int, int>>>& pad_width, |
| |          const std::string& mode, |
| |          const ScalarOrArray& constant_value, |
| |          mx::StreamOrDevice s) { |
| |         if (auto pv = std::get_if<int>(&pad_width); pv) { |
| |           return mx::pad(a, *pv, to_array(constant_value), mode, s); |
| |         } else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) { |
| |           return mx::pad( |
| |               a, std::get<0>(*pv), to_array(constant_value), mode, s); |
| |         } else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) { |
| |           return mx::pad(a, *pv, to_array(constant_value), mode, s); |
| |         } else { |
| |           // Bind by reference: pad_width outlives this call, so there is no |
| |           // need to copy the vector of pairs. |
| |           const auto& v = std::get<std::vector<std::pair<int, int>>>(pad_width); |
| |           if (v.size() == 1) { |
| |             return mx::pad(a, v[0], to_array(constant_value), mode, s); |
| |           } else { |
| |             return mx::pad(a, v, to_array(constant_value), mode, s); |
| |           } |
| |         } |
| |       }, |
| |       nb::arg(), |
| |       "pad_width"_a, |
| |       "mode"_a = "constant", |
| |       "constant_values"_a = 0, |
| |       nb::kw_only(), |
| |       "stream"_a = nb::none(), |
| |       nb::sig( |
| |           "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), |
| |       R"pbdoc( |
| |         Pad an array with a constant value or with its edge values. |
| | |
| |         Args: |
| |             a (array): Input array. |
| |             pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded |
| |               values to add to the edges of each axis:``((before_1, after_1), |
| |               (before_2, after_2), ..., (before_N, after_N))``. If a single pair |
| |               of integers is passed then ``(before_i, after_i)`` are all the same. |
| |               If a single integer or tuple with a single integer is passed then |
| |               all axes are extended by the same number on each side. |
| |             mode (str, optional): Padding mode. One of the following strings: |
| |               "constant" (default): Pads with a constant value. |
| |               "edge": Pads with the edge values of the array. |
| |             constant_values (array or scalar, optional): Optional constant value |
| |               to pad the edges of the array with. |
| | |
| |         Returns: |
| |             array: The padded array. |
| |       )pbdoc"); |
| | m.def( |
| | "as_strided", |
| | [](const mx::array& a, |
| | std::optional<mx::Shape> shape, |
| | std::optional<mx::Strides> strides, |
| | size_t offset, |
| | mx::StreamOrDevice s) { |
| | auto a_shape = (shape) ? *shape : a.shape(); |
| | mx::Strides a_strides; |
| | if (strides) { |
| | a_strides = *strides; |
| | } else { |
| | a_strides = mx::Strides(a_shape.size(), 1); |
| | for (int i = a_shape.size() - 1; i > 0; i--) { |
| | a_strides[i - 1] = a_shape[i] * a_strides[i]; |
| | } |
| | } |
| | return mx::as_strided(a, a_shape, a_strides, offset, s); |
| | }, |
| | nb::arg(), |
| | "shape"_a = nb::none(), |
| | "strides"_a = nb::none(), |
| | "offset"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def as_strided(a: array, /, shape: Optional[Sequence[int]] = None, strides: Optional[Sequence[int]] = None, offset: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Create a view into the array with the given shape and strides. |
| | |
| | The resulting array will always be as if the provided array was row |
| | contiguous regardless of the provided arrays storage order and current |
| | strides. |
| | |
| | .. note:: |
| | Note that this function should be used with caution as it changes |
| | the shape and strides of the array directly. This can lead to the |
| | resulting array pointing to invalid memory locations which can |
| | result into crashes. |
| | |
| | Args: |
| | a (array): Input array |
| | shape (list(int), optional): The shape of the resulting array. If |
| | None it defaults to ``a.shape()``. |
| | strides (list(int), optional): The strides of the resulting array. If |
| | None it defaults to the reverse exclusive cumulative product of |
| | ``a.shape()``. |
| | offset (int): Skip that many elements from the beginning of the input |
| | array. |
| | |
| | Returns: |
| | array: The output array which is the strided view of the input. |
| | )pbdoc"); |
| | m.def( |
| | "cumsum", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::cumsum(a, *axis, reverse, inclusive, s); |
| | } else { |
| | return mx::cumsum(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | nb::arg(), |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def cumsum(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return the cumulative sum of the elements along the given axis. |
| | |
| | Args: |
| | a (array): Input array |
| | axis (int, optional): Optional axis to compute the cumulative sum |
| | over. If unspecified the cumulative sum of the flattened array is |
| | returned. |
| | reverse (bool): Perform the cumulative sum in reverse. |
| | inclusive (bool): The i-th element of the output includes the i-th |
| | element of the input. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "cumprod", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::cumprod(a, *axis, reverse, inclusive, s); |
| | } else { |
| | return mx::cumprod(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | nb::arg(), |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def cumprod(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return the cumulative product of the elements along the given axis. |
| | |
| | Args: |
| | a (array): Input array |
| | axis (int, optional): Optional axis to compute the cumulative product |
| | over. If unspecified the cumulative product of the flattened array is |
| | returned. |
| | reverse (bool): Perform the cumulative product in reverse. |
| | inclusive (bool): The i-th element of the output includes the i-th |
| | element of the input. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "cummax", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::cummax(a, *axis, reverse, inclusive, s); |
| | } else { |
| | return mx::cummax(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | nb::arg(), |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def cummax(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return the cumulative maximum of the elements along the given axis. |
| | |
| | Args: |
| | a (array): Input array |
| | axis (int, optional): Optional axis to compute the cumulative maximum |
| | over. If unspecified the cumulative maximum of the flattened array is |
| | returned. |
| | reverse (bool): Perform the cumulative maximum in reverse. |
| | inclusive (bool): The i-th element of the output includes the i-th |
| | element of the input. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "cummin", |
| | [](const mx::array& a, |
| | std::optional<int> axis, |
| | bool reverse, |
| | bool inclusive, |
| | mx::StreamOrDevice s) { |
| | if (axis) { |
| | return mx::cummin(a, *axis, reverse, inclusive, s); |
| | } else { |
| | return mx::cummin(mx::reshape(a, {-1}, s), 0, reverse, inclusive, s); |
| | } |
| | }, |
| | nb::arg(), |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "reverse"_a = false, |
| | "inclusive"_a = true, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def cummin(a: array, /, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return the cumulative minimum of the elements along the given axis. |
| | |
| | Args: |
| | a (array): Input array |
| | axis (int, optional): Optional axis to compute the cumulative minimum |
| | over. If unspecified the cumulative minimum of the flattened array is |
| | returned. |
| | reverse (bool): Perform the cumulative minimum in reverse. |
| | inclusive (bool): The i-th element of the output includes the i-th |
| | element of the input. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "conj", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::conjugate(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def conj(a: array, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return the elementwise complex conjugate of the input. |
| | Alias for `mx.conjugate`. |
| | |
| | Args: |
| | a (array): Input array |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "conjugate", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::conjugate(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def conjugate(a: array, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return the elementwise complex conjugate of the input. |
| | Alias for `mx.conj`. |
| | |
| | Args: |
| | a (array): Input array |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "convolve", |
| | [](const mx::array& a, |
| | const mx::array& v, |
| | const std::string& mode, |
| | mx::StreamOrDevice s) { |
| | if (a.ndim() != 1 || v.ndim() != 1) { |
| | throw std::invalid_argument("[convolve] Inputs must be 1D."); |
| | } |
| |
|
| | if (a.size() == 0 || v.size() == 0) { |
| | throw std::invalid_argument("[convolve] Inputs cannot be empty."); |
| | } |
| |
|
| | mx::array in = a.size() < v.size() ? v : a; |
| | mx::array wt = a.size() < v.size() ? a : v; |
| | wt = mx::slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s); |
| |
|
| | in = mx::reshape(in, {1, -1, 1}, s); |
| | wt = mx::reshape(wt, {1, -1, 1}, s); |
| |
|
| | int padding = 0; |
| |
|
| | if (mode == "full") { |
| | padding = wt.size() - 1; |
| | } else if (mode == "valid") { |
| | padding = 0; |
| | } else if (mode == "same") { |
| | |
| | if (wt.size() % 2) { |
| | padding = wt.size() / 2; |
| | } else { |
| | int pad_l = wt.size() / 2; |
| | int pad_r = std::max(0, pad_l - 1); |
| | in = mx::pad( |
| | in, |
| | {{0, 0}, {pad_l, pad_r}, {0, 0}}, |
| | mx::array(0), |
| | "constant", |
| | s); |
| | } |
| |
|
| | } else { |
| | throw std::invalid_argument("[convolve] Invalid mode."); |
| | } |
| |
|
| | mx::array out = mx::conv1d( |
| | in, |
| | wt, |
| | 1, |
| | padding, |
| | 1, |
| | 1, |
| | s); |
| |
|
| | return mx::reshape(out, {-1}, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | "mode"_a = "full", |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | R"(def convolve(a: array, v: array, /, mode: str = "full", *, stream: Union[None, Stream, Device] = None) -> array)"), |
| | R"pbdoc( |
| | The discrete convolution of 1D arrays. |
| | |
| | If ``v`` is longer than ``a``, then they are swapped. |
| | The conv filter is flipped following signal processing convention. |
| | |
| | Args: |
| | a (array): 1D Input array. |
| | v (array): 1D Input array. |
| | mode (str, optional): {'full', 'valid', 'same'} |
| | |
| | Returns: |
| | array: The convolved array. |
| | )pbdoc"); |
m.def(
    "conv1d",
    // Bound directly to mx::conv1d. The two leading nb::arg() entries are
    // unnamed (positional: input, weight); the keyword defaults below match
    // the ones advertised in the signature string.
    &mx::conv1d,
    nb::arg(),
    nb::arg(),
    "stride"_a = 1,
    "padding"_a = 0,
    "dilation"_a = 1,
    "groups"_a = 1,
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      1D convolution over an input with several channels

      Args:
          input (array): Input array of shape ``(N, L, C_in)``.
          weight (array): Weight array of shape ``(C_out, K, C_in)``.
          stride (int, optional): Kernel stride. Default: ``1``.
          padding (int, optional): Input padding. Default: ``0``.
          dilation (int, optional): Kernel dilation. Default: ``1``.
          groups (int, optional): Input feature groups. Default: ``1``.

      Returns:
          array: The convolved array.
    )pbdoc");
| | m.def( |
| | "conv2d", |
| | [](const mx::array& input, |
| | const mx::array& weight, |
| | const std::variant<int, std::pair<int, int>>& stride, |
| | const std::variant<int, std::pair<int, int>>& padding, |
| | const std::variant<int, std::pair<int, int>>& dilation, |
| | int groups, |
| | mx::StreamOrDevice s) { |
| | std::pair<int, int> stride_pair{1, 1}; |
| | std::pair<int, int> padding_pair{0, 0}; |
| | std::pair<int, int> dilation_pair{1, 1}; |
| |
|
| | if (auto pv = std::get_if<int>(&stride); pv) { |
| | stride_pair = std::pair<int, int>{*pv, *pv}; |
| | } else { |
| | stride_pair = std::get<std::pair<int, int>>(stride); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&padding); pv) { |
| | padding_pair = std::pair<int, int>{*pv, *pv}; |
| | } else { |
| | padding_pair = std::get<std::pair<int, int>>(padding); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&dilation); pv) { |
| | dilation_pair = std::pair<int, int>{*pv, *pv}; |
| | } else { |
| | dilation_pair = std::get<std::pair<int, int>>(dilation); |
| | } |
| |
|
| | return mx::conv2d( |
| | input, weight, stride_pair, padding_pair, dilation_pair, groups, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | "stride"_a = 1, |
| | "padding"_a = 0, |
| | "dilation"_a = 1, |
| | "groups"_a = 1, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def conv2d(input: array, weight: array, /, stride: Union[int, tuple[int, int]] = 1, padding: Union[int, tuple[int, int]] = 0, dilation: Union[int, tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | 2D convolution over an input with several channels |
| | |
| | Args: |
| | input (array): Input array of shape ``(N, H, W, C_in)``. |
| | weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``. |
| | stride (int or tuple(int), optional): :obj:`tuple` of size 2 with |
| | kernel strides. All spatial dimensions get the same stride if |
| | only one number is specified. Default: ``1``. |
| | padding (int or tuple(int), optional): :obj:`tuple` of size 2 with |
| | symmetric input padding. All spatial dimensions get the same |
| | padding if only one number is specified. Default: ``0``. |
| | dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with |
| | kernel dilation. All spatial dimensions get the same dilation |
| | if only one number is specified. Default: ``1`` |
| | groups (int, optional): input feature groups. Default: ``1``. |
| | |
| | Returns: |
| | array: The convolved array. |
| | )pbdoc"); |
| | m.def( |
| | "conv3d", |
| | [](const mx::array& input, |
| | const mx::array& weight, |
| | const std::variant<int, std::tuple<int, int, int>>& stride, |
| | const std::variant<int, std::tuple<int, int, int>>& padding, |
| | const std::variant<int, std::tuple<int, int, int>>& dilation, |
| | int groups, |
| | mx::StreamOrDevice s) { |
| | std::tuple<int, int, int> stride_tuple{1, 1, 1}; |
| | std::tuple<int, int, int> padding_tuple{0, 0, 0}; |
| | std::tuple<int, int, int> dilation_tuple{1, 1, 1}; |
| |
|
| | if (auto pv = std::get_if<int>(&stride); pv) { |
| | stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; |
| | } else { |
| | stride_tuple = std::get<std::tuple<int, int, int>>(stride); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&padding); pv) { |
| | padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; |
| | } else { |
| | padding_tuple = std::get<std::tuple<int, int, int>>(padding); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&dilation); pv) { |
| | dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; |
| | } else { |
| | dilation_tuple = std::get<std::tuple<int, int, int>>(dilation); |
| | } |
| |
|
| | return mx::conv3d( |
| | input, |
| | weight, |
| | stride_tuple, |
| | padding_tuple, |
| | dilation_tuple, |
| | groups, |
| | s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | "stride"_a = 1, |
| | "padding"_a = 0, |
| | "dilation"_a = 1, |
| | "groups"_a = 1, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def conv3d(input: array, weight: array, /, stride: Union[int, tuple[int, int, int]] = 1, padding: Union[int, tuple[int, int, int]] = 0, dilation: Union[int, tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | 3D convolution over an input with several channels |
| | |
| | Note: Only the default ``groups=1`` is currently supported. |
| | |
| | Args: |
| | input (array): Input array of shape ``(N, D, H, W, C_in)``. |
| | weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``. |
| | stride (int or tuple(int), optional): :obj:`tuple` of size 3 with |
| | kernel strides. All spatial dimensions get the same stride if |
| | only one number is specified. Default: ``1``. |
| | padding (int or tuple(int), optional): :obj:`tuple` of size 3 with |
| | symmetric input padding. All spatial dimensions get the same |
| | padding if only one number is specified. Default: ``0``. |
| | dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with |
| | kernel dilation. All spatial dimensions get the same dilation |
| | if only one number is specified. Default: ``1`` |
| | groups (int, optional): input feature groups. Default: ``1``. |
| | |
| | Returns: |
| | array: The convolved array. |
| | )pbdoc"); |
m.def(
    "conv_transpose1d",
    // Bound directly to mx::conv_transpose1d. The two leading unnamed
    // nb::arg() entries are the positional input and weight; the keyword
    // defaults below match the ones advertised in the signature string.
    &mx::conv_transpose1d,
    nb::arg(),
    nb::arg(),
    "stride"_a = 1,
    "padding"_a = 0,
    "dilation"_a = 1,
    "output_padding"_a = 0,
    "groups"_a = 1,
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
      1D transposed convolution over an input with several channels

      Args:
          input (array): Input array of shape ``(N, L, C_in)``.
          weight (array): Weight array of shape ``(C_out, K, C_in)``.
          stride (int, optional): Kernel stride. Default: ``1``.
          padding (int, optional): Input padding. Default: ``0``.
          dilation (int, optional): Kernel dilation. Default: ``1``.
          output_padding (int, optional): Output padding. Default: ``0``.
          groups (int, optional): Input feature groups. Default: ``1``.

      Returns:
          array: The convolved array.
    )pbdoc");
| | m.def( |
| | "conv_transpose2d", |
| | [](const mx::array& input, |
| | const mx::array& weight, |
| | const std::variant<int, std::pair<int, int>>& stride, |
| | const std::variant<int, std::pair<int, int>>& padding, |
| | const std::variant<int, std::pair<int, int>>& dilation, |
| | const std::variant<int, std::pair<int, int>>& output_padding, |
| | int groups, |
| | mx::StreamOrDevice s) { |
| | std::pair<int, int> stride_pair{1, 1}; |
| | std::pair<int, int> padding_pair{0, 0}; |
| | std::pair<int, int> dilation_pair{1, 1}; |
| | std::pair<int, int> output_padding_pair{0, 0}; |
| |
|
| | if (auto pv = std::get_if<int>(&stride); pv) { |
| | stride_pair = std::pair<int, int>{*pv, *pv}; |
| | } else { |
| | stride_pair = std::get<std::pair<int, int>>(stride); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&padding); pv) { |
| | padding_pair = std::pair<int, int>{*pv, *pv}; |
| | } else { |
| | padding_pair = std::get<std::pair<int, int>>(padding); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&dilation); pv) { |
| | dilation_pair = std::pair<int, int>{*pv, *pv}; |
| | } else { |
| | dilation_pair = std::get<std::pair<int, int>>(dilation); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&output_padding); pv) { |
| | output_padding_pair = std::pair<int, int>{*pv, *pv}; |
| | } else { |
| | output_padding_pair = std::get<std::pair<int, int>>(output_padding); |
| | } |
| |
|
| | return mx::conv_transpose2d( |
| | input, |
| | weight, |
| | stride_pair, |
| | padding_pair, |
| | dilation_pair, |
| | output_padding_pair, |
| | groups, |
| | s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | "stride"_a = 1, |
| | "padding"_a = 0, |
| | "dilation"_a = 1, |
| | "output_padding"_a = 0, |
| | "groups"_a = 1, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def conv_transpose2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | 2D transposed convolution over an input with several channels |
| | |
| | Note: Only the default ``groups=1`` is currently supported. |
| | |
| | Args: |
| | input (array): Input array of shape ``(N, H, W, C_in)``. |
| | weight (array): Weight array of shape ``(C_out, KH, KW, C_in)``. |
| | stride (int or tuple(int), optional): :obj:`tuple` of size 2 with |
| | kernel strides. All spatial dimensions get the same stride if |
| | only one number is specified. Default: ``1``. |
| | padding (int or tuple(int), optional): :obj:`tuple` of size 2 with |
| | symmetric input padding. All spatial dimensions get the same |
| | padding if only one number is specified. Default: ``0``. |
| | dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with |
| | kernel dilation. All spatial dimensions get the same dilation |
| | if only one number is specified. Default: ``1`` |
| | output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with |
| | output padding. All spatial dimensions get the same output |
| | padding if only one number is specified. Default: ``0``. |
| | groups (int, optional): input feature groups. Default: ``1``. |
| | |
| | Returns: |
| | array: The convolved array. |
| | )pbdoc"); |
| | m.def( |
| | "conv_transpose3d", |
| | [](const mx::array& input, |
| | const mx::array& weight, |
| | const std::variant<int, std::tuple<int, int, int>>& stride, |
| | const std::variant<int, std::tuple<int, int, int>>& padding, |
| | const std::variant<int, std::tuple<int, int, int>>& dilation, |
| | const std::variant<int, std::tuple<int, int, int>>& output_padding, |
| | int groups, |
| | mx::StreamOrDevice s) { |
| | std::tuple<int, int, int> stride_tuple{1, 1, 1}; |
| | std::tuple<int, int, int> padding_tuple{0, 0, 0}; |
| | std::tuple<int, int, int> dilation_tuple{1, 1, 1}; |
| | std::tuple<int, int, int> output_padding_tuple{0, 0, 0}; |
| |
|
| | if (auto pv = std::get_if<int>(&stride); pv) { |
| | stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; |
| | } else { |
| | stride_tuple = std::get<std::tuple<int, int, int>>(stride); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&padding); pv) { |
| | padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; |
| | } else { |
| | padding_tuple = std::get<std::tuple<int, int, int>>(padding); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&dilation); pv) { |
| | dilation_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; |
| | } else { |
| | dilation_tuple = std::get<std::tuple<int, int, int>>(dilation); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&output_padding); pv) { |
| | output_padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; |
| | } else { |
| | output_padding_tuple = |
| | std::get<std::tuple<int, int, int>>(output_padding); |
| | } |
| |
|
| | return mx::conv_transpose3d( |
| | input, |
| | weight, |
| | stride_tuple, |
| | padding_tuple, |
| | dilation_tuple, |
| | output_padding_tuple, |
| | groups, |
| | s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | "stride"_a = 1, |
| | "padding"_a = 0, |
| | "dilation"_a = 1, |
| | "output_padding"_a = 0, |
| | "groups"_a = 1, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def conv_transpose3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | 3D transposed convolution over an input with several channels |
| | |
| | Note: Only the default ``groups=1`` is currently supported. |
| | |
| | Args: |
| | input (array): Input array of shape ``(N, D, H, W, C_in)``. |
| | weight (array): Weight array of shape ``(C_out, KD, KH, KW, C_in)``. |
| | stride (int or tuple(int), optional): :obj:`tuple` of size 3 with |
| | kernel strides. All spatial dimensions get the same stride if |
| | only one number is specified. Default: ``1``. |
| | padding (int or tuple(int), optional): :obj:`tuple` of size 3 with |
| | symmetric input padding. All spatial dimensions get the same |
| | padding if only one number is specified. Default: ``0``. |
| | dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with |
| | kernel dilation. All spatial dimensions get the same dilation |
| | if only one number is specified. Default: ``1`` |
| | output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with |
| | output padding. All spatial dimensions get the same output |
| | padding if only one number is specified. Default: ``0``. |
| | groups (int, optional): input feature groups. Default: ``1``. |
| | |
| | Returns: |
| | array: The convolved array. |
| | )pbdoc"); |
| | m.def( |
| | "conv_general", |
| | [](const mx::array& input, |
| | const mx::array& weight, |
| | const std::variant<int, std::vector<int>>& stride, |
| | const std::variant< |
| | int, |
| | std::vector<int>, |
| | std::pair<std::vector<int>, std::vector<int>>>& padding, |
| | const std::variant<int, std::vector<int>>& kernel_dilation, |
| | const std::variant<int, std::vector<int>>& input_dilation, |
| | int groups, |
| | bool flip, |
| | mx::StreamOrDevice s) { |
| | std::vector<int> stride_vec; |
| | std::vector<int> padding_lo_vec; |
| | std::vector<int> padding_hi_vec; |
| | std::vector<int> kernel_dilation_vec; |
| | std::vector<int> input_dilation_vec; |
| |
|
| | if (auto pv = std::get_if<int>(&stride); pv) { |
| | stride_vec.push_back(*pv); |
| | } else { |
| | stride_vec = std::get<std::vector<int>>(stride); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&padding); pv) { |
| | padding_lo_vec.push_back(*pv); |
| | padding_hi_vec.push_back(*pv); |
| | } else if (auto pv = std::get_if<std::vector<int>>(&padding); pv) { |
| | padding_lo_vec = *pv; |
| | padding_hi_vec = *pv; |
| | } else { |
| | auto [pl, ph] = |
| | std::get<std::pair<std::vector<int>, std::vector<int>>>(padding); |
| | padding_lo_vec = pl; |
| | padding_hi_vec = ph; |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&kernel_dilation); pv) { |
| | kernel_dilation_vec.push_back(*pv); |
| | } else { |
| | kernel_dilation_vec = std::get<std::vector<int>>(kernel_dilation); |
| | } |
| |
|
| | if (auto pv = std::get_if<int>(&input_dilation); pv) { |
| | input_dilation_vec.push_back(*pv); |
| | } else { |
| | input_dilation_vec = std::get<std::vector<int>>(input_dilation); |
| | } |
| |
|
| | return mx::conv_general( |
| | std::move(input), |
| | std::move(weight), |
| | std::move(stride_vec), |
| | std::move(padding_lo_vec), |
| | std::move(padding_hi_vec), |
| | |
| | std::move(kernel_dilation_vec), |
| | |
| | std::move(input_dilation_vec), |
| | groups, |
| | flip, |
| | s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | "stride"_a = 1, |
| | "padding"_a = 0, |
| | "kernel_dilation"_a = 1, |
| | "input_dilation"_a = 1, |
| | "groups"_a = 1, |
| | "flip"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | General convolution over an input with several channels |
| | |
| | Args: |
| | input (array): Input array of shape ``(N, ..., C_in)``. |
| | weight (array): Weight array of shape ``(C_out, ..., C_in)``. |
| | stride (int or list(int), optional): :obj:`list` with kernel strides. |
| | All spatial dimensions get the same stride if |
| | only one number is specified. Default: ``1``. |
| | padding (int, list(int), or tuple(list(int), list(int)), optional): |
| | :obj:`list` with input padding. All spatial dimensions get the same |
| | padding if only one number is specified. Default: ``0``. |
| | kernel_dilation (int or list(int), optional): :obj:`list` with |
| | kernel dilation. All spatial dimensions get the same dilation |
| | if only one number is specified. Default: ``1`` |
| | input_dilation (int or list(int), optional): :obj:`list` with |
| | input dilation. All spatial dimensions get the same dilation |
| | if only one number is specified. Default: ``1`` |
| | groups (int, optional): Input feature groups. Default: ``1``. |
| | flip (bool, optional): Flip the order in which the spatial dimensions of |
| | the weights are processed. Performs the cross-correlation operator when |
| | ``flip`` is ``False`` and the convolution operator otherwise. |
| | Default: ``False``. |
| | |
| | Returns: |
| | array: The convolved array. |
| | )pbdoc"); |
m.def(
    "save",
    // Bound directly to mlx_save_helper (presumably declared in
    // python/src/load.h, which this file includes). Both parameters are
    // named, so they may also be passed as keywords.
    &mlx_save_helper,
    "file"_a,
    "arr"_a,
    nb::sig(
        "def save(file: Union[file, str, pathlib.Path], arr: array) -> None"),
    R"pbdoc(
      Save the array to a binary file in ``.npy`` format.

      Args:
          file (str, pathlib.Path, file): File to which the array is saved
          arr (array): Array to be saved.
    )pbdoc");
| | m.def( |
| | "savez", |
| | [](nb::object file, nb::args args, const nb::kwargs& kwargs) { |
| | mlx_savez_helper(file, args, kwargs, false); |
| | }, |
| | "file"_a, |
| | "args"_a, |
| | "kwargs"_a, |
| | nb::sig( |
| | "def savez(file: Union[file, str, pathlib.Path], *args, **kwargs)"), |
| | R"pbdoc( |
| | Save several arrays to a binary file in uncompressed ``.npz`` |
| | format. |
| | |
| | .. code-block:: python |
| | |
| | import mlx.core as mx |
| | |
| | x = mx.ones((10, 10)) |
| | mx.savez("my_path.npz", x=x) |
| | |
| | import mlx.nn as nn |
| | from mlx.utils import tree_flatten |
| | |
| | model = nn.TransformerEncoder(6, 128, 4) |
| | flat_params = tree_flatten(model.parameters()) |
| | mx.savez("model.npz", **dict(flat_params)) |
| | |
| | Args: |
| | file (file, str, pathlib.Path): Path to file to which the arrays are saved. |
| | *args (arrays): Arrays to be saved. |
| | **kwargs (arrays): Arrays to be saved. Each array will be saved |
| | with the associated keyword as the output file name. |
| | )pbdoc"); |
| | m.def( |
| | "savez_compressed", |
| | [](nb::object file, nb::args args, const nb::kwargs& kwargs) { |
| | mlx_savez_helper(file, args, kwargs, true); |
| | }, |
| | nb::arg(), |
| | "args"_a, |
| | "kwargs"_a, |
| | nb::sig( |
| | "def savez_compressed(file: Union[file, str, pathlib.Path], *args, **kwargs)"), |
| | R"pbdoc( |
| | Save several arrays to a binary file in compressed ``.npz`` format. |
| | |
| | Args: |
| | file (file, str, pathlib.Path): Path to file to which the arrays are saved. |
| | *args (arrays): Arrays to be saved. |
| | **kwargs (arrays): Arrays to be saved. Each array will be saved |
| | with the associated keyword as the output file name. |
| | )pbdoc"); |
| | m.def( |
| | "load", |
| | &mlx_load_helper, |
| | nb::arg(), |
| | "format"_a = nb::none(), |
| | "return_metadata"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def load(file: Union[file, str, pathlib.Path], /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"), |
| | R"pbdoc( |
| | Load array(s) from a binary file. |
| | |
| | The supported formats are ``.npy``, ``.npz``, ``.safetensors``, and |
| | ``.gguf``. |
| | |
| | Args: |
| | file (file, str, pathlib.Path): File in which the array is saved. |
| | format (str, optional): Format of the file. If ``None``, the |
| | format is inferred from the file extension. Supported formats: |
| | ``npy``, ``npz``, and ``safetensors``. Default: ``None``. |
| | return_metadata (bool, optional): Load the metadata for formats |
| | which support matadata. The metadata will be returned as an |
| | additional dictionary. Default: ``False``. |
| | Returns: |
| | array or dict: |
| | A single array if loading from a ``.npy`` file or a dict |
| | mapping names to arrays if loading from a ``.npz`` or |
| | ``.safetensors`` file. If ``return_metadata`` is ``True`` an |
| | additional dictionary of metadata will be returned. |
| | |
| | Warning: |
| | |
| | When loading unsupported quantization formats from GGUF, tensors |
| | will automatically cast to ``mx.float16`` |
| | )pbdoc"); |
  // Python binding for `mx.save_safetensors`. Delegates to
  // mlx_save_safetensor_helper. All three arguments are named ("file"_a,
  // "arrays"_a, "metadata"_a), so they may be passed by keyword; `metadata`
  // defaults to None. Note this binding takes no `stream` argument.
  m.def(
      "save_safetensors",
      &mlx_save_safetensor_helper,
      "file"_a,
      "arrays"_a,
      "metadata"_a = nb::none(),
      nb::sig(
          "def save_safetensors(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
      R"pbdoc(
        Save array(s) to a binary file in ``.safetensors`` format.

        See the `Safetensors documentation
        <https://huggingface.co/docs/safetensors/index>`_ for more
        information on the format.

        Args:
            file (file, str, pathlib.Path): File in which the array is saved.
            arrays (dict(str, array)): The dictionary of names to arrays to
              be saved.
            metadata (dict(str, str), optional): The dictionary of
              metadata to be saved.
      )pbdoc");
| | m.def( |
| | "save_gguf", |
| | &mlx_save_gguf_helper, |
| | "file"_a, |
| | "arrays"_a, |
| | "metadata"_a = nb::none(), |
| | nb::sig( |
| | "def save_gguf(file: Union[file, str, pathlib.Path], arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"), |
| | R"pbdoc( |
| | Save array(s) to a binary file in ``.gguf`` format. |
| | |
| | See the `GGUF documentation |
| | <https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>`_ for |
| | more information on the format. |
| | |
| | Args: |
| | file (file, str, pathlib.Path): File in which the array is saved. |
| | arrays (dict(str, array)): The dictionary of names to arrays to |
| | be saved. |
| | metadata (dict(str, Union[array, str, list(str)])): The dictionary |
| | of metadata to be saved. The values can be a scalar or 1D |
| | obj:`array`, a :obj:`str`, or a :obj:`list` of :obj:`str`. |
| | )pbdoc"); |
| | m.def( |
| | "where", |
| | [](const ScalarOrArray& condition, |
| | const ScalarOrArray& x_, |
| | const ScalarOrArray& y_, |
| | mx::StreamOrDevice s) { |
| | auto [x, y] = to_arrays(x_, y_); |
| | return mx::where(to_array(condition), x, y, s); |
| | }, |
| | "condition"_a, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def where(condition: Union[scalar, array], x: Union[scalar, array], y: Union[scalar, array], /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Select from ``x`` or ``y`` according to ``condition``. |
| | |
| | The condition and input arrays must be the same shape or |
| | broadcastable with each another. |
| | |
| | Args: |
| | condition (array): The condition array. |
| | x (array): The input selected from where condition is ``True``. |
| | y (array): The input selected from where condition is ``False``. |
| | |
| | Returns: |
| | array: The output containing elements selected from |
| | ``x`` and ``y``. |
| | )pbdoc"); |
| | m.def( |
| | "nan_to_num", |
| | [](const ScalarOrArray& a, |
| | float nan, |
| | std::optional<float>& posinf, |
| | std::optional<float>& neginf, |
| | mx::StreamOrDevice s) { |
| | return mx::nan_to_num(to_array(a), nan, posinf, neginf, s); |
| | }, |
| | nb::arg(), |
| | "nan"_a = 0.0f, |
| | "posinf"_a = nb::none(), |
| | "neginf"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def nan_to_num(a: Union[scalar, array], nan: float = 0, posinf: Optional[float] = None, neginf: Optional[float] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Replace NaN and Inf values with finite numbers. |
| | |
| | Args: |
| | a (array): Input array |
| | nan (float, optional): Value to replace NaN with. Default: ``0``. |
| | posinf (float, optional): Value to replace positive infinities |
| | with. If ``None``, defaults to largest finite value for the |
| | given data type. Default: ``None``. |
| | neginf (float, optional): Value to replace negative infinities |
| | with. If ``None``, defaults to the negative of the largest |
| | finite value for the given data type. Default: ``None``. |
| | |
| | Returns: |
| | array: Output array with NaN and Inf replaced. |
| | )pbdoc"); |
| | m.def( |
| | "round", |
| | [](const ScalarOrArray& a, int decimals, mx::StreamOrDevice s) { |
| | return mx::round(to_array(a), decimals, s); |
| | }, |
| | nb::arg(), |
| | "decimals"_a = 0, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Round to the given number of decimals. |
| | |
| | Basically performs: |
| | |
| | .. code-block:: python |
| | |
| | s = 10**decimals |
| | x = round(x * s) / s |
| | |
| | Args: |
| | a (array): Input array |
| | decimals (int): Number of decimal places to round to. (default: 0) |
| | |
| | Returns: |
| | array: An array of the same type as ``a`` rounded to the |
| | given number of decimals. |
| | )pbdoc"); |
  // Direct binding of mx::quantized_matmul. `x` and `w` are positional-only
  // (unnamed nb::arg()); `scales` through `mode` may be positional or
  // keyword; `stream` is keyword-only. `biases` defaults to None, matching
  // the Optional[array] annotation in the signature override.
  m.def(
      "quantized_matmul",
      &mx::quantized_matmul,
      nb::arg(),
      nb::arg(),
      "scales"_a,
      "biases"_a = nb::none(),
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
      "mode"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform the matrix multiplication with the quantized matrix ``w``. The
        quantization uses one floating point scale and bias per ``group_size`` of
        elements. Each element in ``w`` takes ``bits`` bits and is packed in an
        unsigned 32 bit integer.

        Args:
          x (array): Input array
          w (array): Quantized matrix packed in unsigned integers
          scales (array): The scales to use per ``group_size`` elements of ``w``
          biases (array, optional): The biases to use per ``group_size``
            elements of ``w``. Default: ``None``.
          transpose (bool, optional): Defines whether to multiply with the
            transposed ``w`` or not, namely whether we are performing
            ``x @ w.T`` or ``x @ w``. Default: ``True``.
          group_size (int, optional): The size of the group in ``w`` that
            shares a scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element in
            ``w``. Default: ``4``.
          mode (str, optional): The quantization mode. Default: ``"affine"``.

        Returns:
          array: The result of the multiplication of ``x`` with ``w``.
      )pbdoc");
  // Direct binding of mx::quantize. `w` is positional-only; `stream` is
  // keyword-only.
  // NOTE(review): the signature override annotates the result as
  // tuple[array, array, array], but the docstring says the result has two
  // or three elements (no biases outside "affine" mode). Confirm the C++
  // return type of mx::quantize and align the annotation.
  m.def(
      "quantize",
      &mx::quantize,
      nb::arg(),
      "group_size"_a = 64,
      "bits"_a = 4,
      "mode"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
      R"pbdoc(
        Quantize the matrix ``w`` using ``bits`` bits per element.

        Note, every ``group_size`` elements in a row of ``w`` are quantized
        together. Hence, number of columns of ``w`` should be divisible by
        ``group_size``. In particular, the rows of ``w`` are divided into groups of
        size ``group_size`` which are quantized together.

        .. warning::

          ``quantize`` currently only supports 2D inputs with the second
          dimension divisible by ``group_size``

        The supported quantization modes are ``"affine"`` and ``"mxfp4"``. They
        are described in more detail below.

        Args:
          w (array): Matrix to be quantized
          group_size (int, optional): The size of the group in ``w`` that shares a
            scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element of
            ``w`` in the returned quantized matrix. Default: ``4``.
          mode (str, optional): The quantization mode. Default: ``"affine"``.

        Returns:
          tuple: A tuple with either two or three elements containing:

          * w_q (array): The quantized version of ``w``
          * scales (array): The quantization scales
          * biases (array): The quantization biases (returned for ``mode=="affine"``).

        Notes:
          The ``affine`` mode quantizes groups of :math:`g` consecutive
          elements in a row of ``w``. For each group the quantized
          representation of each element :math:`\hat{w_i}` is computed as follows:

          .. math::

            \begin{aligned}
              \alpha &= \max_i w_i \\
              \beta &= \min_i w_i \\
              s &= \frac{\alpha - \beta}{2^b - 1} \\
              \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right).
            \end{aligned}

          After the above computation, :math:`\hat{w_i}` fits in :math:`b` bits
          and is packed in an unsigned 32-bit integer from the lower to upper
          bits. For instance, for 4-bit quantization we fit 8 elements in an
          unsigned 32 bit integer where the 1st element occupies the 4 least
          significant bits, the 2nd bits 4-7 etc.

          To dequantize the elements of ``w``, we also save :math:`s` and
          :math:`\beta` which are the returned ``scales`` and
          ``biases`` respectively.

          The ``mxfp4`` mode similarly quantizes groups of :math:`g` elements
          of ``w``. For ``mxfp4`` the group size must be ``32``. The elements
          are quantized to 4-bit precision floating-point values (E2M1) with a
          shared 8-bit scale per group. Unlike ``affine`` quantization,
          ``mxfp4`` does not have a bias value. More details on the format can
          be found in the `specification <https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf>`_.
      )pbdoc");
  // Direct binding of mx::dequantize. `w` is positional-only; `stream` is
  // keyword-only. `biases` is optional (None by default); per the docstring
  // of `quantize`, only "affine" mode produces biases.
  m.def(
      "dequantize",
      &mx::dequantize,
      nb::arg(),
      "scales"_a,
      "biases"_a = nb::none(),
      "group_size"_a = 64,
      "bits"_a = 4,
      "mode"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Dequantize the matrix ``w`` using quantization parameters.

        Args:
          w (array): Matrix to be dequantized
          scales (array): The scales to use per ``group_size`` elements of ``w``.
          biases (array, optional): The biases to use per ``group_size``
             elements of ``w``. Default: ``None``.
          group_size (int, optional): The size of the group in ``w`` that shares a
            scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element in
            ``w``. Default: ``4``.
          mode (str, optional): The quantization mode. Default: ``"affine"``.

        Returns:
          array: The dequantized version of ``w``

        Notes:
          The currently supported quantization modes are ``"affine"`` and ``mxfp4``.

          For ``affine`` quantization, given the notation in :func:`quantize`,
          we compute :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s`
          and :math:`\beta` as follows

          .. math::

            w_i = s \hat{w_i} + \beta
      )pbdoc");
  // Direct binding of mx::gather_qmm. `x` and `w` are positional-only.
  // `sorted_indices` and `stream` follow nb::kw_only() and are therefore
  // keyword-only, as reflected by the `*` in the signature override.
  m.def(
      "gather_qmm",
      &mx::gather_qmm,
      nb::arg(),
      nb::arg(),
      "scales"_a,
      "biases"_a = nb::none(),
      "lhs_indices"_a = nb::none(),
      "rhs_indices"_a = nb::none(),
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
      "mode"_a = "affine",
      nb::kw_only(),
      "sorted_indices"_a = false,
      "stream"_a = nb::none(),
      nb::sig(
          "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform quantized matrix multiplication with matrix-level gather.

        This operation is the quantized equivalent to :func:`gather_mm`.
        Similar to :func:`gather_mm`, the indices ``lhs_indices`` and
        ``rhs_indices`` contain flat indices along the batch dimensions (i.e.
        all but the last two dimensions) of ``x`` and ``w`` respectively.

        Note that ``scales`` and ``biases`` must have the same batch dimensions
        as ``w`` since they represent the same quantized matrix.

        Args:
            x (array): Input array
            w (array): Quantized matrix packed in unsigned integers
            scales (array): The scales to use per ``group_size`` elements of ``w``
            biases (array, optional): The biases to use per ``group_size``
              elements of ``w``. Default: ``None``.
            lhs_indices (array, optional): Integer indices for ``x``. Default: ``None``.
            rhs_indices (array, optional): Integer indices for ``w``. Default: ``None``.
            transpose (bool, optional): Defines whether to multiply with the
              transposed ``w`` or not, namely whether we are performing
              ``x @ w.T`` or ``x @ w``. Default: ``True``.
            group_size (int, optional): The size of the group in ``w`` that
              shares a scale and bias. Default: ``64``.
            bits (int, optional): The number of bits occupied by each element in
              ``w``. Default: ``4``.
            mode (str, optional): The quantization mode. Default: ``"affine"``.
            sorted_indices (bool, optional): May allow a faster implementation
              if the passed indices are sorted. Default: ``False``.

        Returns:
            array: The result of the multiplication of ``x`` with ``w``
              after gathering using ``lhs_indices`` and ``rhs_indices``.
      )pbdoc");
  // Direct binding of mx::segmented_mm. `a` and `b` are positional-only;
  // `segments` may be positional or keyword; `stream` is keyword-only.
  m.def(
      "segmented_mm",
      &mx::segmented_mm,
      nb::arg(),
      nb::arg(),
      "segments"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def segmented_mm(a: array, b: array, /, segments: array, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform a matrix multiplication but segment the inner dimension and
        save the result for each segment separately.

        Args:
          a (array): Input array of shape ``MxK``.
          b (array): Input array of shape ``KxN``.
          segments (array): The offsets into the inner dimension for each segment.

        Returns:
          array: The result per segment of shape ``MxN``.
      )pbdoc");
| | m.def( |
| | "tensordot", |
| | [](const mx::array& a, |
| | const mx::array& b, |
| | const std::variant<int, std::vector<std::vector<int>>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (auto pv = std::get_if<int>(&axes); pv) { |
| | return mx::tensordot(a, b, *pv, s); |
| | } else { |
| | auto& x = std::get<std::vector<std::vector<int>>>(axes); |
| | if (x.size() != 2) { |
| | throw std::invalid_argument( |
| | "[tensordot] axes must be a list of two lists."); |
| | } |
| | return mx::tensordot(a, b, x[0], x[1], s); |
| | } |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | "axes"_a = 2, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def tensordot(a: array, b: array, /, axes: Union[int, list[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Compute the tensor dot product along the specified axes. |
| | |
| | Args: |
| | a (array): Input array |
| | b (array): Input array |
| | axes (int or list(list(int)), optional): The number of dimensions to |
| | sum over. If an integer is provided, then sum over the last |
| | ``axes`` dimensions of ``a`` and the first ``axes`` dimensions of |
| | ``b``. If a list of lists is provided, then sum over the |
| | corresponding dimensions of ``a`` and ``b``. Default: 2. |
| | |
| | Returns: |
| | array: The tensor dot product. |
| | )pbdoc"); |
  // Direct binding of mx::inner: two positional-only arrays plus a
  // keyword-only `stream`.
  m.def(
      "inner",
      &mx::inner,
      nb::arg(),
      nb::arg(),
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def inner(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
      Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes.

      Args:
        a (array): Input array
        b (array): Input array

      Returns:
        array: The inner product.
    )pbdoc");
| | m.def( |
| | "outer", |
| | &mx::outer, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def outer(a: array, b: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Compute the outer product of two 1-D arrays, if the array's passed are not 1-D a flatten op will be run beforehand. |
| | |
| | Args: |
| | a (array): Input array |
| | b (array): Input array |
| | |
| | Returns: |
| | array: The outer product. |
| | )pbdoc"); |
| | m.def( |
| | "tile", |
| | [](const mx::array& a, |
| | const std::variant<int, std::vector<int>>& reps, |
| | mx::StreamOrDevice s) { |
| | if (auto pv = std::get_if<int>(&reps); pv) { |
| | return mx::tile(a, {*pv}, s); |
| | } else { |
| | return mx::tile(a, std::get<std::vector<int>>(reps), s); |
| | } |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def tile(a: array, reps: Union[int, Sequence[int]], /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Construct an array by repeating ``a`` the number of times given by ``reps``. |
| | |
| | Args: |
| | a (array): Input array |
| | reps (int or list(int)): The number of times to repeat ``a`` along each axis. |
| | |
| | Returns: |
| | array: The tiled array. |
| | )pbdoc"); |
  // Direct binding of mx::addmm. `c`, `a`, `b` are positional-only;
  // `alpha` and `beta` default to 1.0 and may be passed by position or
  // keyword; `stream` is keyword-only.
  m.def(
      "addmm",
      &mx::addmm,
      nb::arg(),
      nb::arg(),
      nb::arg(),
      "alpha"_a = 1.0f,
      "beta"_a = 1.0f,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Matrix multiplication with addition and optional scaling.

        Perform the (possibly batched) matrix multiplication of two arrays and add to the result
        with optional scaling factors.

        Args:
            c (array): Input array or scalar.
            a (array): Input array or scalar.
            b (array): Input array or scalar.
            alpha (float, optional): Scaling factor for the
                matrix product of ``a`` and ``b`` (default: ``1``)
            beta (float, optional): Scaling factor for ``c`` (default: ``1``)

        Returns:
            array: ``alpha * (a @ b)  + beta * c``
      )pbdoc");
| | m.def( |
| | "block_masked_mm", |
| | &mx::block_masked_mm, |
| | nb::arg(), |
| | nb::arg(), |
| | "block_size"_a = 64, |
| | "mask_out"_a = nb::none(), |
| | "mask_lhs"_a = nb::none(), |
| | "mask_rhs"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Matrix multiplication with block masking. |
| | |
| | Perform the (possibly batched) matrix multiplication of two arrays and with blocks |
| | of size ``block_size x block_size`` optionally masked out. |
| | |
| | Assuming ``a`` with shape (..., `M`, `K`) and b with shape (..., `K`, `N`) |
| | |
| | * ``lhs_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `K` / ``block_size`` :math:`\rceil`) |
| | |
| | * ``rhs_mask`` must have shape (..., :math:`\lceil` `K` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`) |
| | |
| | * ``out_mask`` must have shape (..., :math:`\lceil` `M` / ``block_size`` :math:`\rceil`, :math:`\lceil` `N` / ``block_size`` :math:`\rceil`) |
| | |
| | Note: Only ``block_size=64`` and ``block_size=32`` are currently supported |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | block_size (int): Size of blocks to be masked. Must be ``32`` or ``64``. Default: ``64``. |
| | mask_out (array, optional): Mask for output. Default: ``None``. |
| | mask_lhs (array, optional): Mask for ``a``. Default: ``None``. |
| | mask_rhs (array, optional): Mask for ``b``. Default: ``None``. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "gather_mm", |
| | &mx::gather_mm, |
| | nb::arg(), |
| | nb::arg(), |
| | "lhs_indices"_a = nb::none(), |
| | "rhs_indices"_a = nb::none(), |
| | nb::kw_only(), |
| | "sorted_indices"_a = false, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Matrix multiplication with matrix-level gather. |
| | |
| | Performs a gather of the operands with the given indices followed by a |
| | (possibly batched) matrix multiplication of two arrays. This operation |
| | is more efficient than explicitly applying a :func:`take` followed by a |
| | :func:`matmul`. |
| | |
| | The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices |
| | along the batch dimensions (i.e. all but the last two dimensions) of |
| | ``a`` and ``b`` respectively. |
| | |
| | For ``a`` with shape ``(A1, A2, ..., AS, M, K)``, ``lhs_indices`` |
| | contains indices from the range ``[0, A1 * A2 * ... * AS)`` |
| | |
| | For ``b`` with shape ``(B1, B2, ..., BS, M, K)``, ``rhs_indices`` |
| | contains indices from the range ``[0, B1 * B2 * ... * BS)`` |
| | |
| | If only one index is passed and it is sorted, the ``sorted_indices`` |
| | flag can be passed for a possible faster implementation. |
| | |
| | Args: |
| | a (array): Input array. |
| | b (array): Input array. |
| | lhs_indices (array, optional): Integer indices for ``a``. Default: ``None`` |
| | rhs_indices (array, optional): Integer indices for ``b``. Default: ``None`` |
| | sorted_indices (bool, optional): May allow a faster implementation |
| | if the passed indices are sorted. Default: ``False``. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
  // Direct binding of mx::diagonal. Unlike most bindings here, every
  // argument is named ("a"_a, ...) and there is no nb::kw_only(), so all of
  // them, including `stream`, may be passed by position or keyword; the
  // signature override accordingly has no `/` or `*`.
  m.def(
      "diagonal",
      &mx::diagonal,
      "a"_a,
      "offset"_a = 0,
      "axis1"_a = 0,
      "axis2"_a = 1,
      "stream"_a = nb::none(),
      nb::sig(
          "def diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Return specified diagonals.

        If ``a`` is 2-D, then a 1-D array containing the diagonal at the given
        ``offset`` is returned.

        If ``a`` has more than two dimensions, then ``axis1`` and ``axis2``
        determine the 2D subarrays from which diagonals are extracted. The new
        shape is the original shape with ``axis1`` and ``axis2`` removed and a
        new dimension inserted at the end corresponding to the diagonal.

        Args:
          a (array): Input array
          offset (int, optional): Offset of the diagonal from the main diagonal.
            Can be positive or negative. Default: ``0``.
          axis1 (int, optional): The first axis of the 2-D sub-arrays from which
              the diagonals should be taken. Default: ``0``.
          axis2 (int, optional): The second axis of the 2-D sub-arrays from which
              the diagonals should be taken. Default: ``1``.

        Returns:
            array: The diagonals of the array.
      )pbdoc");
  // Direct binding of mx::diag. `a` is positional-only, `k` defaults to 0,
  // and `stream` is keyword-only.
  m.def(
      "diag",
      &mx::diag,
      nb::arg(),
      "k"_a = 0,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def diag(a: array, /, k: int = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Extract a diagonal or construct a diagonal matrix.
        If ``a`` is 1-D then a diagonal matrix is constructed with ``a`` on the
        :math:`k`-th diagonal. If ``a`` is 2-D then the :math:`k`-th diagonal is
        returned.

        Args:
            a (array): 1-D or 2-D input array.
            k (int, optional): The diagonal to extract or construct.
                Default: ``0``.

        Returns:
            array: The extracted diagonal or the constructed diagonal matrix.
        )pbdoc");
| | m.def( |
| | "trace", |
| | [](const mx::array& a, |
| | int offset, |
| | int axis1, |
| | int axis2, |
| | std::optional<mx::Dtype> dtype, |
| | mx::StreamOrDevice s) { |
| | if (!dtype.has_value()) { |
| | return mx::trace(a, offset, axis1, axis2, s); |
| | } |
| | return mx::trace(a, offset, axis1, axis2, dtype.value(), s); |
| | }, |
| | nb::arg(), |
| | "offset"_a = 0, |
| | "axis1"_a = 0, |
| | "axis2"_a = 1, |
| | "dtype"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Return the sum along a specified diagonal in the given array. |
| | |
| | Args: |
| | a (array): Input array |
| | offset (int, optional): Offset of the diagonal from the main diagonal. |
| | Can be positive or negative. Default: ``0``. |
| | axis1 (int, optional): The first axis of the 2-D sub-arrays from which |
| | the diagonals should be taken. Default: ``0``. |
| | axis2 (int, optional): The second axis of the 2-D sub-arrays from which |
| | the diagonals should be taken. Default: ``1``. |
| | dtype (Dtype, optional): Data type of the output array. If |
| | unspecified the output type is inferred from the input array. |
| | |
| | Returns: |
| | array: Sum of specified diagonal. |
| | )pbdoc"); |
| | m.def( |
| | "atleast_1d", |
| | [](const nb::args& arys, mx::StreamOrDevice s) -> nb::object { |
| | if (arys.size() == 1) { |
| | return nb::cast(mx::atleast_1d(nb::cast<mx::array>(arys[0]), s)); |
| | } |
| | return nb::cast( |
| | mx::atleast_1d(nb::cast<std::vector<mx::array>>(arys), s)); |
| | }, |
| | "arys"_a, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"), |
| | R"pbdoc( |
| | Convert all arrays to have at least one dimension. |
| | |
| | Args: |
| | *arys: Input arrays. |
| | stream (Union[None, Stream, Device], optional): The stream to execute the operation on. |
| | |
| | Returns: |
| | array or list(array): An array or list of arrays with at least one dimension. |
| | )pbdoc"); |
| | m.def( |
| | "atleast_2d", |
| | [](const nb::args& arys, mx::StreamOrDevice s) -> nb::object { |
| | if (arys.size() == 1) { |
| | return nb::cast(mx::atleast_2d(nb::cast<mx::array>(arys[0]), s)); |
| | } |
| | return nb::cast( |
| | mx::atleast_2d(nb::cast<std::vector<mx::array>>(arys), s)); |
| | }, |
| | "arys"_a, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"), |
| | R"pbdoc( |
| | Convert all arrays to have at least two dimensions. |
| | |
| | Args: |
| | *arys: Input arrays. |
| | stream (Union[None, Stream, Device], optional): The stream to execute the operation on. |
| | |
| | Returns: |
| | array or list(array): An array or list of arrays with at least two dimensions. |
| | )pbdoc"); |
| | m.def( |
| | "atleast_3d", |
| | [](const nb::args& arys, mx::StreamOrDevice s) -> nb::object { |
| | if (arys.size() == 1) { |
| | return nb::cast(mx::atleast_3d(nb::cast<mx::array>(arys[0]), s)); |
| | } |
| | return nb::cast( |
| | mx::atleast_3d(nb::cast<std::vector<mx::array>>(arys), s)); |
| | }, |
| | "arys"_a, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"), |
| | R"pbdoc( |
| | Convert all arrays to have at least three dimensions. |
| | |
| | Args: |
| | *arys: Input arrays. |
| | stream (Union[None, Stream, Device], optional): The stream to execute the operation on. |
| | |
| | Returns: |
| | array or list(array): An array or list of arrays with at least three dimensions. |
| | )pbdoc"); |
| | m.def( |
| | "issubdtype", |
| | [](const nb::object& d1, const nb::object& d2) { |
| | auto dispatch_second = [](const auto& t1, const auto& d2) { |
| | if (nb::isinstance<mx::Dtype>(d2)) { |
| | return mx::issubdtype(t1, nb::cast<mx::Dtype>(d2)); |
| | } else if (nb::isinstance<mx::Dtype::Category>(d2)) { |
| | return mx::issubdtype(t1, nb::cast<mx::Dtype::Category>(d2)); |
| | } else { |
| | throw std::invalid_argument( |
| | "[issubdtype] Received invalid type for second input."); |
| | } |
| | }; |
| | if (nb::isinstance<mx::Dtype>(d1)) { |
| | return dispatch_second(nb::cast<mx::Dtype>(d1), d2); |
| | } else if (nb::isinstance<mx::Dtype::Category>(d1)) { |
| | return dispatch_second(nb::cast<mx::Dtype::Category>(d1), d2); |
| | } else { |
| | throw std::invalid_argument( |
| | "[issubdtype] Received invalid type for first input."); |
| | } |
| | }, |
| | ""_a, |
| | ""_a, |
| | nb::sig( |
| | "def issubdtype(arg1: Union[Dtype, DtypeCategory], arg2: Union[Dtype, DtypeCategory]) -> bool"), |
| | R"pbdoc( |
| | Check if a :obj:`Dtype` or :obj:`DtypeCategory` is a subtype |
| | of another. |
| | |
| | Args: |
| | arg1 (Union[Dtype, DtypeCategory]: First dtype or category. |
| | arg2 (Union[Dtype, DtypeCategory]: Second dtype or category. |
| | |
| | Returns: |
| | bool: |
| | A boolean indicating if the first input is a subtype of the |
| | second input. |
| | |
| | Example: |
| | |
| | >>> ints = mx.array([1, 2, 3], dtype=mx.int32) |
| | >>> mx.issubdtype(ints.dtype, mx.integer) |
| | True |
| | >>> mx.issubdtype(ints.dtype, mx.floating) |
| | False |
| | |
| | >>> floats = mx.array([1, 2, 3], dtype=mx.float32) |
| | >>> mx.issubdtype(floats.dtype, mx.integer) |
| | False |
| | >>> mx.issubdtype(floats.dtype, mx.floating) |
| | True |
| | |
| | Similar types of different sizes are not subdtypes of each other: |
| | |
| | >>> mx.issubdtype(mx.float64, mx.float32) |
| | False |
| | >>> mx.issubdtype(mx.float32, mx.float64) |
| | False |
| | |
| | but both are subtypes of `floating`: |
| | |
| | >>> mx.issubdtype(mx.float64, mx.floating) |
| | True |
| | >>> mx.issubdtype(mx.float32, mx.floating) |
| | True |
| | |
| | For convenience, dtype-like objects are allowed too: |
| | |
| | >>> mx.issubdtype(mx.float32, mx.inexact) |
| | True |
| | >>> mx.issubdtype(mx.signedinteger, mx.floating) |
| | False |
| | )pbdoc"); |
| | m.def( |
| | "bitwise_and", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::bitwise_and(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def bitwise_and(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise bitwise and. |
| | |
| | Take the bitwise and of two arrays with numpy-style broadcasting |
| | semantics. Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The bitwise and ``a & b``. |
| | )pbdoc"); |
| | m.def( |
| | "bitwise_or", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::bitwise_or(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def bitwise_or(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise bitwise or. |
| | |
| | Take the bitwise or of two arrays with numpy-style broadcasting |
| | semantics. Either or both input arrays can also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The bitwise or``a | b``. |
| | )pbdoc"); |
| | m.def( |
| | "bitwise_xor", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::bitwise_xor(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def bitwise_xor(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise bitwise xor. |
| | |
| | Take the bitwise exclusive or of two arrays with numpy-style |
| | broadcasting semantics. Either or both input arrays can also be |
| | scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The bitwise xor ``a ^ b``. |
| | )pbdoc"); |
| | m.def( |
| | "left_shift", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::left_shift(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def left_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise left shift. |
| | |
| | Shift the bits of the first input to the left by the second using |
| | numpy-style broadcasting semantics. Either or both input arrays can |
| | also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The bitwise left shift ``a << b``. |
| | )pbdoc"); |
| | m.def( |
| | "right_shift", |
| | [](const ScalarOrArray& a_, |
| | const ScalarOrArray& b_, |
| | mx::StreamOrDevice s) { |
| | auto [a, b] = to_arrays(a_, b_); |
| | return mx::right_shift(a, b, s); |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def right_shift(a: Union[scalar, array], b: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise right shift. |
| | |
| | Shift the bits of the first input to the right by the second using |
| | numpy-style broadcasting semantics. Either or both input arrays can |
| | also be scalars. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | b (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The bitwise right shift ``a >> b``. |
| | )pbdoc"); |
| | m.def( |
| | "bitwise_invert", |
| | [](const ScalarOrArray& a_, mx::StreamOrDevice s) { |
| | auto a = to_array(a_); |
| | return mx::bitwise_invert(a, s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def bitwise_invert(a: Union[scalar, array], stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Element-wise bitwise inverse. |
| | |
| | Take the bitwise complement of the input. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | |
| | Returns: |
| | array: The bitwise inverse ``~a``. |
| | )pbdoc"); |
| | m.def( |
| | "view", |
| | [](const ScalarOrArray& a, const mx::Dtype& dtype, mx::StreamOrDevice s) { |
| | return mx::view(to_array(a), dtype, s); |
| | }, |
| | nb::arg(), |
| | "dtype"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def view(a: Union[scalar, array], dtype: Dtype, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | View the array as a different type. |
| | |
| | The output shape changes along the last axis if the input array's |
| | type and the input ``dtype`` do not have the same size. |
| | |
| | Note: the view op does not imply that the input and output arrays share |
| | their underlying data. The view only gaurantees that the binary |
| | representation of each element (or group of elements) is the same. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | dtype (Dtype): The data type to change to. |
| | |
| | Returns: |
| | array: The array with the new type. |
| | )pbdoc"); |
| | m.def( |
| | "hadamard_transform", |
| | &mx::hadamard_transform, |
| | nb::arg(), |
| | "scale"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def hadamard_transform(a: array, scale: Optional[float] = None, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Perform the Walsh-Hadamard transform along the final axis. |
| | |
| | Equivalent to: |
| | |
| | .. code-block:: python |
| | |
| | from scipy.linalg import hadamard |
| | |
| | y = (hadamard(len(x)) @ x) * scale |
| | |
| | Supports sizes ``n = m*2^k`` for ``m`` in ``(1, 12, 20, 28)`` and ``2^k |
| | <= 8192`` for float32 and ``2^k <= 16384`` for float16/bfloat16. |
| | |
| | Args: |
| | a (array): Input array or scalar. |
| | scale (float): Scale the output by this factor. |
| | Defaults to ``1/sqrt(a.shape[-1])`` so that the Hadamard matrix is orthonormal. |
| | |
| | Returns: |
| | array: The transformed array. |
| | )pbdoc"); |
| | m.def( |
| | "einsum_path", |
| | [](const std::string& equation, const nb::args& operands) { |
| | auto arrays_list = nb::cast<std::vector<mx::array>>(operands); |
| | auto [path, str] = mx::einsum_path(equation, arrays_list); |
| | |
| | std::vector<nb::tuple> tuple_path; |
| | for (auto& p : path) { |
| | tuple_path.push_back(nb::tuple(nb::cast(p))); |
| | } |
| | return std::make_pair(tuple_path, str); |
| | }, |
| | "subscripts"_a, |
| | "operands"_a, |
| | nb::sig("def einsum_path(subscripts: str, *operands)"), |
| | R"pbdoc( |
| | |
| | Compute the contraction order for the given Einstein summation. |
| | |
| | Args: |
| | subscripts (str): The Einstein summation convention equation. |
| | *operands (array): The input arrays. |
| | |
| | Returns: |
| | tuple(list(tuple(int, int)), str): |
| | The einsum path and a string containing information about the |
| | chosen path. |
| | )pbdoc"); |
| | m.def( |
| | "einsum", |
| | [](const std::string& subscripts, |
| | const nb::args& operands, |
| | mx::StreamOrDevice s) { |
| | auto arrays_list = nb::cast<std::vector<mx::array>>(operands); |
| | return mx::einsum(subscripts, arrays_list, s); |
| | }, |
| | "subscripts"_a, |
| | "operands"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def einsum(subscripts: str, *operands, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | |
| | Perform the Einstein summation convention on the operands. |
| | |
| | Args: |
| | subscripts (str): The Einstein summation convention equation. |
| | *operands (array): The input arrays. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| | m.def( |
| | "roll", |
| | [](const mx::array& a, |
| | const std::variant<int, mx::Shape>& shift, |
| | const IntOrVec& axis, |
| | mx::StreamOrDevice s) { |
| | return std::visit( |
| | [&](auto sh, auto ax) -> mx::array { |
| | if constexpr (std::is_same_v<decltype(ax), std::monostate>) { |
| | return mx::roll(a, sh, s); |
| | } else { |
| | return mx::roll(a, sh, ax, s); |
| | } |
| | }, |
| | shift, |
| | axis); |
| | }, |
| | nb::arg(), |
| | "shift"_a, |
| | "axis"_a = nb::none(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def roll(a: array, shift: Union[int, Tuple[int]], axis: Union[None, int, Tuple[int]] = None, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Roll array elements along a given axis. |
| | |
| | Elements that are rolled beyond the end of the array are introduced at |
| | the beggining and vice-versa. |
| | |
| | If the axis is not provided the array is flattened, rolled and then the |
| | shape is restored. |
| | |
| | Args: |
| | a (array): Input array |
| | shift (int or tuple(int)): The number of places by which elements |
| | are shifted. If positive the array is rolled to the right, if |
| | negative it is rolled to the left. If an int is provided but the |
| | axis is a tuple then the same value is used for all axes. |
| | axis (int or tuple(int), optional): The axis or axes along which to |
| | roll the elements. |
| | )pbdoc"); |
| | m.def( |
| | "real", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::real(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def real(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Returns the real part of a complex array. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The real part of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "imag", |
| | [](const ScalarOrArray& a, mx::StreamOrDevice s) { |
| | return mx::imag(to_array(a), s); |
| | }, |
| | nb::arg(), |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def imag(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Returns the imaginary part of a complex array. |
| | |
| | Args: |
| | a (array): Input array. |
| | |
| | Returns: |
| | array: The imaginary part of ``a``. |
| | )pbdoc"); |
| | m.def( |
| | "slice", |
| | [](const mx::array& a, |
| | const mx::array& start_indices, |
| | std::vector<int> axes, |
| | mx::Shape slice_size, |
| | mx::StreamOrDevice s) { |
| | return mx::slice( |
| | a, start_indices, std::move(axes), std::move(slice_size), s); |
| | }, |
| | nb::arg(), |
| | "start_indices"_a, |
| | "axes"_a, |
| | "slice_size"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def slice(a: array, start_indices: array, axes: Sequence[int], slice_size: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Extract a sub-array from the input array. |
| | |
| | Args: |
| | a (array): Input array |
| | start_indices (array): The index location to start the slice at. |
| | axes (tuple(int)): The axes corresponding to the indices in ``start_indices``. |
| | slice_size (tuple(int)): The size of the slice. |
| | |
| | Returns: |
| | array: The sliced output array. |
| | |
| | Example: |
| | |
| | >>> a = mx.array([[1, 2, 3], [4, 5, 6]]) |
| | >>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2)) |
| | array([[4, 5]], dtype=int32) |
| | >>> |
| | >>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1)) |
| | array([[2], |
| | [5]], dtype=int32) |
| | )pbdoc"); |
| | m.def( |
| | "slice_update", |
| | [](const mx::array& src, |
| | const mx::array& update, |
| | const mx::array& start_indices, |
| | std::vector<int> axes, |
| | mx::StreamOrDevice s) { |
| | return mx::slice_update(src, update, start_indices, axes, s); |
| | }, |
| | nb::arg(), |
| | "update"_a, |
| | "start_indices"_a, |
| | "axes"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def slice_update(a: array, update: array, start_indices: array, axes: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Update a sub-array of the input array. |
| | |
| | Args: |
| | a (array): The input array to update |
| | update (array): The update array. |
| | start_indices (array): The index location to start the slice at. |
| | axes (tuple(int)): The axes corresponding to the indices in ``start_indices``. |
| | |
| | Returns: |
| | array: The output array with the same shape and type as the input. |
| | |
| | Example: |
| | |
| | >>> a = mx.zeros((3, 3)) |
| | >>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array(1, 1), axes=(0, 1)) |
| | array([[0, 0, 0], |
| | [0, 1, 0], |
| | [0, 1, 0]], dtype=float32) |
| | )pbdoc"); |
| | m.def( |
| | "contiguous", |
| | &mx::contiguous, |
| | nb::arg(), |
| | "allow_col_major"_a = false, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Force an array to be row contiguous. Copy if necessary. |
| | |
| | Args: |
| | a (array): The input to make contiguous |
| | allow_col_major (bool): Consider column major as contiguous and don't copy |
| | |
| | Returns: |
| | array: The row or col contiguous output. |
| | )pbdoc"); |
| | m.def( |
| | "broadcast_shapes", |
| | [](const nb::args& shapes) { |
| | if (shapes.size() == 0) |
| | throw std::invalid_argument( |
| | "[broadcast_shapes] Must provide at least one shape."); |
| |
|
| | mx::Shape result = nb::cast<mx::Shape>(shapes[0]); |
| | for (size_t i = 1; i < shapes.size(); ++i) { |
| | if (!nb::isinstance<mx::Shape>(shapes[i]) && |
| | !nb::isinstance<nb::tuple>(shapes[i])) |
| | throw std::invalid_argument( |
| | "[broadcast_shapes] Expects a sequence of shapes (tuple or list of ints)."); |
| | result = mx::broadcast_shapes(result, nb::cast<mx::Shape>(shapes[i])); |
| | } |
| |
|
| | return nb::tuple(nb::cast(result)); |
| | }, |
| | nb::sig("def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int]"), |
| | R"pbdoc( |
| | Broadcast shapes. |
| | |
| | Returns the shape that results from broadcasting the supplied array shapes |
| | against each other. |
| | |
| | Args: |
| | *shapes (Sequence[int]): The shapes to broadcast. |
| | |
| | Returns: |
| | tuple: The broadcasted shape. |
| | |
| | Raises: |
| | ValueError: If the shapes cannot be broadcast. |
| | |
| | Example: |
| | >>> mx.broadcast_shapes((1,), (3, 1)) |
| | (3, 1) |
| | >>> mx.broadcast_shapes((6, 7), (5, 6, 1), (7,)) |
| | (5, 6, 7) |
| | >>> mx.broadcast_shapes((5, 1, 4), (1, 3, 1)) |
| | (5, 3, 4) |
| | )pbdoc"); |
| | m.def( |
| | "depends", |
| | [](const nb::object& inputs_, const nb::object& deps_) { |
| | bool return_vec = false; |
| | std::vector<mx::array> inputs; |
| | std::vector<mx::array> deps; |
| | if (nb::isinstance<mx::array>(inputs_)) { |
| | inputs = {nb::cast<mx::array>(inputs_)}; |
| | } else { |
| | return_vec = true; |
| | inputs = {nb::cast<std::vector<mx::array>>(inputs_)}; |
| | } |
| | if (nb::isinstance<mx::array>(deps_)) { |
| | deps = {nb::cast<mx::array>(deps_)}; |
| | } else { |
| | deps = {nb::cast<std::vector<mx::array>>(deps_)}; |
| | } |
| | auto out = depends(inputs, deps); |
| | if (return_vec) { |
| | return nb::cast(out); |
| | } else { |
| | return nb::cast(out[0]); |
| | } |
| | }, |
| | nb::arg(), |
| | nb::arg(), |
| | nb::sig( |
| | "def depends(inputs: Union[array, Sequence[array]], dependencies: Union[array, Sequence[array]])"), |
| | R"pbdoc( |
| | Insert dependencies between arrays in the graph. The outputs are |
| | identical to ``inputs`` but with dependencies on ``dependencies``. |
| | |
| | Args: |
| | inputs (array or Sequence[array]): The input array or arrays. |
| | dependencies (array or Sequence[array]): The array or arrays |
| | to insert dependencies on. |
| | |
| | Returns: |
| | array or Sequence[array]: The outputs which depend on dependencies. |
| | )pbdoc"); |
| | } |
| |
|