| | |
| |
|
| | #include <nanobind/nanobind.h> |
| | #include <nanobind/stl/optional.h> |
| | #include <nanobind/stl/variant.h> |
| | #include <nanobind/stl/vector.h> |
| | #include <numeric> |
| |
|
| | #include "mlx/fft.h" |
| | #include "mlx/ops.h" |
| | #include "python/src/small_vector.h" |
| |
|
| | namespace mx = mlx::core; |
| | namespace nb = nanobind; |
| | using namespace nb::literals; |
| |
|
| | void init_fft(nb::module_& parent_module) { |
| | auto m = parent_module.def_submodule( |
| | "fft", "mlx.core.fft: Fast Fourier Transforms."); |
| | m.def( |
| | "fft", |
| | [](const mx::array& a, |
| | const std::optional<int>& n, |
| | int axis, |
| | mx::StreamOrDevice s) { |
| | if (n.has_value()) { |
| | return mx::fft::fft(a, n.value(), axis, s); |
| | } else { |
| | return mx::fft::fft(a, axis, s); |
| | } |
| | }, |
| | "a"_a, |
| | "n"_a = nb::none(), |
| | "axis"_a = -1, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | One dimensional discrete Fourier Transform. |
| | |
| | Args: |
| | a (array): The input array. |
| | n (int, optional): Size of the transformed axis. The |
| | corresponding axis in the input is truncated or padded with |
| | zeros to match ``n``. The default value is ``a.shape[axis]``. |
| | axis (int, optional): Axis along which to perform the FFT. The |
| | default is ``-1``. |
| | |
| | Returns: |
| | array: The DFT of the input along the given axis. |
| | )pbdoc"); |
| | m.def( |
| | "ifft", |
| | [](const mx::array& a, |
| | const std::optional<int>& n, |
| | int axis, |
| | mx::StreamOrDevice s) { |
| | if (n.has_value()) { |
| | return mx::fft::ifft(a, n.value(), axis, s); |
| | } else { |
| | return mx::fft::ifft(a, axis, s); |
| | } |
| | }, |
| | "a"_a, |
| | "n"_a = nb::none(), |
| | "axis"_a = -1, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | One dimensional inverse discrete Fourier Transform. |
| | |
| | Args: |
| | a (array): The input array. |
| | n (int, optional): Size of the transformed axis. The |
| | corresponding axis in the input is truncated or padded with |
| | zeros to match ``n``. The default value is ``a.shape[axis]``. |
| | axis (int, optional): Axis along which to perform the FFT. The |
| | default is ``-1``. |
| | |
| | Returns: |
| | array: The inverse DFT of the input along the given axis. |
| | )pbdoc"); |
| | m.def( |
| | "fft2", |
| | [](const mx::array& a, |
| | const std::optional<mx::Shape>& n, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value() && n.has_value()) { |
| | return mx::fft::fftn(a, n.value(), axes.value(), s); |
| | } else if (axes.has_value()) { |
| | return mx::fft::fftn(a, axes.value(), s); |
| | } else if (n.has_value()) { |
| | throw std::invalid_argument( |
| | "[fft2] `axes` should not be `None` if `s` is not `None`."); |
| | } else { |
| | return mx::fft::fftn(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "s"_a = nb::none(), |
| | "axes"_a.none() = std::vector<int>{-2, -1}, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Two dimensional discrete Fourier Transform. |
| | |
| | Args: |
| | a (array): The input array. |
| | s (list(int), optional): Sizes of the transformed axes. The |
| | corresponding axes in the input are truncated or padded with |
| | zeros to match the sizes in ``s``. The default value is the |
| | sizes of ``a`` along ``axes``. |
| | axes (list(int), optional): Axes along which to perform the FFT. |
| | The default is ``[-2, -1]``. |
| | |
| | Returns: |
| | array: The DFT of the input along the given axes. |
| | )pbdoc"); |
| | m.def( |
| | "ifft2", |
| | [](const mx::array& a, |
| | const std::optional<mx::Shape>& n, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value() && n.has_value()) { |
| | return mx::fft::ifftn(a, n.value(), axes.value(), s); |
| | } else if (axes.has_value()) { |
| | return mx::fft::ifftn(a, axes.value(), s); |
| | } else if (n.has_value()) { |
| | throw std::invalid_argument( |
| | "[ifft2] `axes` should not be `None` if `s` is not `None`."); |
| | } else { |
| | return mx::fft::ifftn(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "s"_a = nb::none(), |
| | "axes"_a.none() = std::vector<int>{-2, -1}, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Two dimensional inverse discrete Fourier Transform. |
| | |
| | Args: |
| | a (array): The input array. |
| | s (list(int), optional): Sizes of the transformed axes. The |
| | corresponding axes in the input are truncated or padded with |
| | zeros to match the sizes in ``s``. The default value is the |
| | sizes of ``a`` along ``axes``. |
| | axes (list(int), optional): Axes along which to perform the FFT. |
| | The default is ``[-2, -1]``. |
| | |
| | Returns: |
| | array: The inverse DFT of the input along the given axes. |
| | )pbdoc"); |
| | m.def( |
| | "fftn", |
| | [](const mx::array& a, |
| | const std::optional<mx::Shape>& n, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value() && n.has_value()) { |
| | return mx::fft::fftn(a, n.value(), axes.value(), s); |
| | } else if (axes.has_value()) { |
| | return mx::fft::fftn(a, axes.value(), s); |
| | } else if (n.has_value()) { |
| | throw std::invalid_argument( |
| | "[fftn] `axes` should not be `None` if `s` is not `None`."); |
| | } else { |
| | return mx::fft::fftn(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "s"_a = nb::none(), |
| | "axes"_a = nb::none(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | n-dimensional discrete Fourier Transform. |
| | |
| | Args: |
| | a (array): The input array. |
| | s (list(int), optional): Sizes of the transformed axes. The |
| | corresponding axes in the input are truncated or padded with |
| | zeros to match the sizes in ``s``. The default value is the |
| | sizes of ``a`` along ``axes``. |
| | axes (list(int), optional): Axes along which to perform the FFT. |
| | The default is ``None`` in which case the FFT is over the last |
| | ``len(s)`` axes are or all axes if ``s`` is also ``None``. |
| | |
| | Returns: |
| | array: The DFT of the input along the given axes. |
| | )pbdoc"); |
| | m.def( |
| | "ifftn", |
| | [](const mx::array& a, |
| | const std::optional<mx::Shape>& n, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value() && n.has_value()) { |
| | return mx::fft::ifftn(a, n.value(), axes.value(), s); |
| | } else if (axes.has_value()) { |
| | return mx::fft::ifftn(a, axes.value(), s); |
| | } else if (n.has_value()) { |
| | throw std::invalid_argument( |
| | "[ifftn] `axes` should not be `None` if `s` is not `None`."); |
| | } else { |
| | return mx::fft::ifftn(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "s"_a = nb::none(), |
| | "axes"_a = nb::none(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | n-dimensional inverse discrete Fourier Transform. |
| | |
| | Args: |
| | a (array): The input array. |
| | s (list(int), optional): Sizes of the transformed axes. The |
| | corresponding axes in the input are truncated or padded with |
| | zeros to match the sizes in ``s``. The default value is the |
| | sizes of ``a`` along ``axes``. |
| | axes (list(int), optional): Axes along which to perform the FFT. |
| | The default is ``None`` in which case the FFT is over the last |
| | ``len(s)`` axes or all axes if ``s`` is also ``None``. |
| | |
| | Returns: |
| | array: The inverse DFT of the input along the given axes. |
| | )pbdoc"); |
| | m.def( |
| | "rfft", |
| | [](const mx::array& a, |
| | const std::optional<int>& n, |
| | int axis, |
| | mx::StreamOrDevice s) { |
| | if (n.has_value()) { |
| | return mx::fft::rfft(a, n.value(), axis, s); |
| | } else { |
| | return mx::fft::rfft(a, axis, s); |
| | } |
| | }, |
| | "a"_a, |
| | "n"_a = nb::none(), |
| | "axis"_a = -1, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | One dimensional discrete Fourier Transform on a real input. |
| | |
| | The output has the same shape as the input except along ``axis`` in |
| | which case it has size ``n // 2 + 1``. |
| | |
| | Args: |
| | a (array): The input array. If the array is complex it will be silently |
| | cast to a real type. |
| | n (int, optional): Size of the transformed axis. The |
| | corresponding axis in the input is truncated or padded with |
| | zeros to match ``n``. The default value is ``a.shape[axis]``. |
| | axis (int, optional): Axis along which to perform the FFT. The |
| | default is ``-1``. |
| | |
| | Returns: |
| | array: The DFT of the input along the given axis. The output |
| | data type will be complex. |
| | )pbdoc"); |
| | m.def( |
| | "irfft", |
| | [](const mx::array& a, |
| | const std::optional<int>& n, |
| | int axis, |
| | mx::StreamOrDevice s) { |
| | if (n.has_value()) { |
| | return mx::fft::irfft(a, n.value(), axis, s); |
| | } else { |
| | return mx::fft::irfft(a, axis, s); |
| | } |
| | }, |
| | "a"_a, |
| | "n"_a = nb::none(), |
| | "axis"_a = -1, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | The inverse of :func:`rfft`. |
| | |
| | The output has the same shape as the input except along ``axis`` in |
| | which case it has size ``n``. |
| | |
| | Args: |
| | a (array): The input array. |
| | n (int, optional): Size of the transformed axis. The |
| | corresponding axis in the input is truncated or padded with |
| | zeros to match ``n // 2 + 1``. The default value is |
| | ``a.shape[axis] // 2 + 1``. |
| | axis (int, optional): Axis along which to perform the FFT. The |
| | default is ``-1``. |
| | |
| | Returns: |
| | array: The real array containing the inverse of :func:`rfft`. |
| | )pbdoc"); |
| | m.def( |
| | "rfft2", |
| | [](const mx::array& a, |
| | const std::optional<mx::Shape>& n, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value() && n.has_value()) { |
| | return mx::fft::rfftn(a, n.value(), axes.value(), s); |
| | } else if (axes.has_value()) { |
| | return mx::fft::rfftn(a, axes.value(), s); |
| | } else if (n.has_value()) { |
| | throw std::invalid_argument( |
| | "[rfft2] `axes` should not be `None` if `s` is not `None`."); |
| | } else { |
| | return mx::fft::rfftn(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "s"_a = nb::none(), |
| | "axes"_a.none() = std::vector<int>{-2, -1}, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Two dimensional real discrete Fourier Transform. |
| | |
| | The output has the same shape as the input except along the dimensions in |
| | ``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is |
| | treated as the real axis and will have size ``s[-1] // 2 + 1``. |
| | |
| | Args: |
| | a (array): The input array. If the array is complex it will be silently |
| | cast to a real type. |
| | s (list(int), optional): Sizes of the transformed axes. The |
| | corresponding axes in the input are truncated or padded with |
| | zeros to match the sizes in ``s``. The default value is the |
| | sizes of ``a`` along ``axes``. |
| | axes (list(int), optional): Axes along which to perform the FFT. |
| | The default is ``[-2, -1]``. |
| | |
| | Returns: |
| | array: The real DFT of the input along the given axes. The output |
| | data type will be complex. |
| | )pbdoc"); |
| | m.def( |
| | "irfft2", |
| | [](const mx::array& a, |
| | const std::optional<mx::Shape>& n, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value() && n.has_value()) { |
| | return mx::fft::irfftn(a, n.value(), axes.value(), s); |
| | } else if (axes.has_value()) { |
| | return mx::fft::irfftn(a, axes.value(), s); |
| | } else if (n.has_value()) { |
| | throw std::invalid_argument( |
| | "[irfft2] `axes` should not be `None` if `s` is not `None`."); |
| | } else { |
| | return mx::fft::irfftn(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "s"_a = nb::none(), |
| | "axes"_a.none() = std::vector<int>{-2, -1}, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | The inverse of :func:`rfft2`. |
| | |
| | Note the input is generally complex. The dimensions of the input |
| | specified in ``axes`` are padded or truncated to match the sizes |
| | from ``s``. The last axis in ``axes`` is treated as the real axis |
| | and will have size ``s[-1] // 2 + 1``. |
| | |
| | Args: |
| | a (array): The input array. |
| | s (list(int), optional): Sizes of the transformed axes. The |
| | corresponding axes in the input are truncated or padded with |
| | zeros to match the sizes in ``s`` except for the last axis |
| | which has size ``s[-1] // 2 + 1``. The default value is the |
| | sizes of ``a`` along ``axes``. |
| | axes (list(int), optional): Axes along which to perform the FFT. |
| | The default is ``[-2, -1]``. |
| | |
| | Returns: |
| | array: The real array containing the inverse of :func:`rfft2`. |
| | )pbdoc"); |
| | m.def( |
| | "rfftn", |
| | [](const mx::array& a, |
| | const std::optional<mx::Shape>& n, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value() && n.has_value()) { |
| | return mx::fft::rfftn(a, n.value(), axes.value(), s); |
| | } else if (axes.has_value()) { |
| | return mx::fft::rfftn(a, axes.value(), s); |
| | } else if (n.has_value()) { |
| | throw std::invalid_argument( |
| | "[rfftn] `axes` should not be `None` if `s` is not `None`."); |
| | } else { |
| | return mx::fft::rfftn(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "s"_a = nb::none(), |
| | "axes"_a = nb::none(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | n-dimensional real discrete Fourier Transform. |
| | |
| | The output has the same shape as the input except along the dimensions in |
| | ``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is |
| | treated as the real axis and will have size ``s[-1] // 2 + 1``. |
| | |
| | Args: |
| | a (array): The input array. If the array is complex it will be silently |
| | cast to a real type. |
| | s (list(int), optional): Sizes of the transformed axes. The |
| | corresponding axes in the input are truncated or padded with |
| | zeros to match the sizes in ``s``. The default value is the |
| | sizes of ``a`` along ``axes``. |
| | axes (list(int), optional): Axes along which to perform the FFT. |
| | The default is ``None`` in which case the FFT is over the last |
| | ``len(s)`` axes or all axes if ``s`` is also ``None``. |
| | |
| | Returns: |
| | array: The real DFT of the input along the given axes. The output |
| | )pbdoc"); |
| | m.def( |
| | "irfftn", |
| | [](const mx::array& a, |
| | const std::optional<mx::Shape>& n, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value() && n.has_value()) { |
| | return mx::fft::irfftn(a, n.value(), axes.value(), s); |
| | } else if (axes.has_value()) { |
| | return mx::fft::irfftn(a, axes.value(), s); |
| | } else if (n.has_value()) { |
| | throw std::invalid_argument( |
| | "[irfftn] `axes` should not be `None` if `s` is not `None`."); |
| | } else { |
| | return mx::fft::irfftn(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "s"_a = nb::none(), |
| | "axes"_a = nb::none(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | The inverse of :func:`rfftn`. |
| | |
| | Note the input is generally complex. The dimensions of the input |
| | specified in ``axes`` are padded or truncated to match the sizes |
| | from ``s``. The last axis in ``axes`` is treated as the real axis |
| | and will have size ``s[-1] // 2 + 1``. |
| | |
| | Args: |
| | a (array): The input array. |
| | s (list(int), optional): Sizes of the transformed axes. The |
| | corresponding axes in the input are truncated or padded with |
| | zeros to match the sizes in ``s``. The default value is the |
| | sizes of ``a`` along ``axes``. |
| | axes (list(int), optional): Axes along which to perform the FFT. |
| | The default is ``None`` in which case the FFT is over the last |
| | ``len(s)`` axes or all axes if ``s`` is also ``None``. |
| | |
| | Returns: |
| | array: The real array containing the inverse of :func:`rfftn`. |
| | )pbdoc"); |
| | m.def( |
| | "fftshift", |
| | [](const mx::array& a, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value()) { |
| | return mx::fft::fftshift(a, axes.value(), s); |
| | } else { |
| | return mx::fft::fftshift(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "axes"_a = nb::none(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Shift the zero-frequency component to the center of the spectrum. |
| | |
| | Args: |
| | a (array): The input array. |
| | axes (list(int), optional): Axes over which to perform the shift. |
| | If ``None``, shift all axes. |
| | |
| | Returns: |
| | array: The shifted array with the same shape as the input. |
| | )pbdoc"); |
| | m.def( |
| | "ifftshift", |
| | [](const mx::array& a, |
| | const std::optional<std::vector<int>>& axes, |
| | mx::StreamOrDevice s) { |
| | if (axes.has_value()) { |
| | return mx::fft::ifftshift(a, axes.value(), s); |
| | } else { |
| | return mx::fft::ifftshift(a, s); |
| | } |
| | }, |
| | "a"_a, |
| | "axes"_a = nb::none(), |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | The inverse of :func:`fftshift`. While identical to :func:`fftshift` for even-length axes, |
| | the behavior differs for odd-length axes. |
| | |
| | Args: |
| | a (array): The input array. |
| | axes (list(int), optional): Axes over which to perform the inverse shift. |
| | If ``None``, shift all axes. |
| | |
| | Returns: |
| | array: The inverse-shifted array with the same shape as the input. |
| | )pbdoc"); |
| | } |
| |
|