| // Copyright © 2023-2024 Apple Inc. | |
| namespace nb = nanobind; | |
| using namespace nb::literals; | |
| NB_MODULE(_ext, m) { | |
| m.doc() = "Sample extension for MLX"; | |
| m.def( | |
| "axpby", | |
| &my_ext::axpby, | |
| "x"_a, | |
| "y"_a, | |
| "alpha"_a, | |
| "beta"_a, | |
| nb::kw_only(), | |
| "stream"_a = nb::none(), | |
| R"( | |
| Scale and sum two vectors element-wise | |
| ``z = alpha * x + beta * y`` | |
| Follows numpy style broadcasting between ``x`` and ``y`` | |
| Inputs are upcasted to floats if needed | |
| Args: | |
| x (array): Input array. | |
| y (array): Input array. | |
| alpha (float): Scaling factor for ``x``. | |
| beta (float): Scaling factor for ``y``. | |
| Returns: | |
| array: ``alpha * x + beta * y`` | |
| )"); | |
| } | |