| | |
| |
|
| | #include <nanobind/nanobind.h> |
| | #include <nanobind/stl/optional.h> |
| | #include <nanobind/stl/pair.h> |
| | #include <nanobind/stl/string.h> |
| | #include <nanobind/stl/tuple.h> |
| | #include <nanobind/stl/variant.h> |
| | #include <nanobind/stl/vector.h> |
| |
|
| | #include "mlx/fast.h" |
| | #include "mlx/ops.h" |
| | #include "python/src/small_vector.h" |
| | #include "python/src/utils.h" |
| |
|
| | namespace mx = mlx::core; |
| | namespace nb = nanobind; |
| | using namespace nb::literals; |
| |
|
| | namespace { |
| |
|
| | struct PyCustomKernelFunction { |
| | PyCustomKernelFunction(mx::fast::CustomKernelFunction kernel, const char* tag) |
| | : kernel_(std::move(kernel)), tag_(tag) {} |
| |
|
| | std::vector<mx::array> operator()( |
| | const std::vector<ScalarOrArray>& inputs_, |
| | const std::vector<mx::Shape>& output_shapes, |
| | const std::vector<mx::Dtype>& output_dtypes, |
| | std::tuple<int, int, int> grid, |
| | std::tuple<int, int, int> threadgroup, |
| | const std::optional<std::vector<std::pair<std::string, nb::object>>>& |
| | template_args_ = std::nullopt, |
| | std::optional<float> init_value = std::nullopt, |
| | bool verbose = false, |
| | mx::StreamOrDevice s = {}) const { |
| | std::vector<mx::array> inputs; |
| | for (const auto& value : inputs_) { |
| | inputs.push_back(to_array(value, std::nullopt)); |
| | } |
| | std::vector<std::pair<std::string, mx::fast::TemplateArg>> template_args; |
| | if (template_args_) { |
| | for (const auto& [name, value] : template_args_.value()) { |
| | |
| | if (nb::isinstance<bool>(value)) { |
| | bool bool_val = nb::cast<bool>(value); |
| | template_args.emplace_back(name, bool_val); |
| | } else if (nb::isinstance<int>(value)) { |
| | int int_val = nb::cast<int>(value); |
| | template_args.emplace_back(name, int_val); |
| | } else if (nb::isinstance<mx::Dtype>(value)) { |
| | mx::Dtype dtype = nb::cast<mx::Dtype>(value); |
| | template_args.emplace_back(name, dtype); |
| | } else { |
| | std::ostringstream msg; |
| | msg << tag_ |
| | << " Invalid template argument. Must be `mlx.core.Dtype`, `int` or `bool`."; |
| | throw std::invalid_argument(msg.str()); |
| | } |
| | } |
| | } |
| | return kernel_( |
| | inputs, |
| | output_shapes, |
| | output_dtypes, |
| | grid, |
| | threadgroup, |
| | template_args, |
| | init_value, |
| | verbose, |
| | s); |
| | } |
| |
|
| | mx::fast::CustomKernelFunction kernel_; |
| | const char* tag_; |
| | }; |
| |
|
| | } |
| |
|
| | void init_fast(nb::module_& parent_module) { |
| | auto m = |
| | parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); |
| |
|
| | m.def( |
| | "rms_norm", |
| | &mx::fast::rms_norm, |
| | "x"_a, |
| | "weight"_a.none(), |
| | "eps"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def rms_norm(x: array, weight: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Root Mean Square normalization (RMS norm). |
| | |
| | The normalization is with respect to the last axis of the input ``x``. |
| | |
| | Args: |
| | x (array): Input array. |
| | weight (array, optional): A multiplicative weight to scale the result by. |
| | The ``weight`` should be one-dimensional with the same size |
| | as the last axis of ``x``. If set to ``None`` then no scaling happens. |
| | eps (float): A small additive constant for numerical stability. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| |
|
| | m.def( |
| | "layer_norm", |
| | &mx::fast::layer_norm, |
| | "x"_a, |
| | "weight"_a.none(), |
| | "bias"_a.none(), |
| | "eps"_a, |
| | nb::kw_only(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def layer_norm(x: array, weight: Optional[array], bias: Optional[array], eps: float, *, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Layer normalization. |
| | |
| | The normalization is with respect to the last axis of the input ``x``. |
| | |
| | Args: |
| | x (array): Input array. |
| | weight (array, optional): A multiplicative weight to scale the result by. |
| | The ``weight`` should be one-dimensional with the same size |
| | as the last axis of ``x``. If set to ``None`` then no scaling happens. |
| | bias (array, optional): An additive offset to be added to the result. |
| | The ``bias`` should be one-dimensional with the same size |
| | as the last axis of ``x``. If set to ``None`` then no translation happens. |
| | eps (float): A small additive constant for numerical stability. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| |
|
| | m.def( |
| | "rope", |
| | [](const mx::array& a, |
| | int dims, |
| | bool traditional, |
| | std::optional<float> base, |
| | float scale, |
| | const ScalarOrArray& offset, |
| | const std::optional<mx::array>& freqs , |
| | mx::StreamOrDevice s ) { |
| | return mx::fast::rope( |
| | a, dims, traditional, base, scale, to_array(offset), freqs, s); |
| | }, |
| | "a"_a, |
| | "dims"_a, |
| | nb::kw_only(), |
| | "traditional"_a, |
| | "base"_a.none(), |
| | "scale"_a, |
| | "offset"_a, |
| | "freqs"_a = nb::none(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def rope(a: array, dims: int, *, traditional: bool, base: Optional[float], scale: float, offset: Union[int, array], freqs: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | Apply rotary positional encoding to the input. |
| | |
| | The input is expected to be at least 3D with shape ``(B, *, T, D)`` where: |
| | * ``B`` is the batch size. |
| | * ``T`` is the sequence length. |
| | * ``D`` is the feature dimension. |
| | |
| | Args: |
| | a (array): The input array. |
| | dims (int): The feature dimensions to be rotated. If the input feature |
| | is larger than dims then the rest is left unchanged. |
| | traditional (bool): If set to ``True`` choose the traditional |
| | implementation which rotates consecutive dimensions. |
| | base (float, optional): The base used to compute angular frequency for |
| | each dimension in the positional encodings. Exactly one of ``base`` and |
| | ``freqs`` must be ``None``. |
| | scale (float): The scale used to scale the positions. |
| | offset (int or array): The position offset to start at. If an |
| | :obj:`array` is given it can be a scalar or vector of ``B`` |
| | offsets for each example in the batch. |
| | freqs (array, optional): Optional frequencies to use with RoPE. |
| | If set, the ``base`` parameter must be ``None``. Default: ``None``. |
| | |
| | Returns: |
| | array: The output array. |
| | )pbdoc"); |
| |
|
| | m.def( |
| | "scaled_dot_product_attention", |
| | [](const mx::array& queries, |
| | const mx::array& keys, |
| | const mx::array& values, |
| | const float scale, |
| | const std::variant<std::monostate, std::string, mx::array>& mask, |
| | const std::optional<mx::array>& sinks, |
| | mx::StreamOrDevice s) { |
| | bool has_mask = !std::holds_alternative<std::monostate>(mask); |
| | bool has_str_mask = |
| | has_mask && std::holds_alternative<std::string>(mask); |
| | bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask); |
| |
|
| | if (has_mask) { |
| | if (has_str_mask) { |
| | auto mask_str = std::get<std::string>(mask); |
| | if (mask_str != "causal") { |
| | std::ostringstream msg; |
| | msg << "[scaled_dot_product_attention] invalid mask option '" |
| | << mask_str << "'. Must be 'causal', or an array."; |
| | throw std::invalid_argument(msg.str()); |
| | } |
| | return mx::fast::scaled_dot_product_attention( |
| | queries, keys, values, scale, mask_str, {}, sinks, s); |
| | } else { |
| | auto mask_arr = std::get<mx::array>(mask); |
| | return mx::fast::scaled_dot_product_attention( |
| | queries, keys, values, scale, "", {mask_arr}, sinks, s); |
| | } |
| |
|
| | } else { |
| | return mx::fast::scaled_dot_product_attention( |
| | queries, keys, values, scale, "", {}, sinks, s); |
| | } |
| | }, |
| | "q"_a, |
| | "k"_a, |
| | "v"_a, |
| | nb::kw_only(), |
| | "scale"_a, |
| | "mask"_a = nb::none(), |
| | "sinks"_a = nb::none(), |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, sinks: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), |
| | R"pbdoc( |
| | A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. |
| | |
| | Supports: |
| | |
| | * `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_ |
| | * `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_ |
| | * `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_ |
| | |
| | .. note:: |
| | |
| | * The softmax operation is performed in ``float32`` regardless of |
| | the input precision. |
| | * For Grouped Query Attention and Multi-Query Attention, the ``k`` |
| | and ``v`` inputs should not be pre-tiled to match ``q``. |
| | |
| | In the following the dimensions are given by: |
| | |
| | * ``B``: The batch size. |
| | * ``N_q``: The number of query heads. |
| | * ``N_kv``: The number of key and value heads. |
| | * ``T_q``: The number of queries per example. |
| | * ``T_kv``: The number of keys and values per example. |
| | * ``D``: The per-head dimension. |
| | |
| | Args: |
| | q (array): Queries with shape ``[B, N_q, T_q, D]``. |
| | k (array): Keys with shape ``[B, N_kv, T_kv, D]``. |
| | v (array): Values with shape ``[B, N_kv, T_kv, D]``. |
| | scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``). |
| | mask (str or array, optional): The mask to apply to the |
| | query-key scores. The mask can be an array or a string indicating |
| | the mask type. The only supported string type is ``"causal"``. If |
| | the mask is an array it can be a boolean or additive mask. The mask |
| | can have at most 4 dimensions and must be broadcast-compatible with |
| | the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its |
| | type must promote to the promoted type of ``q``, ``k``, and ``v``. |
| | sinks (array, optional): An optional array of attention sinks. |
| | Default: ``None``. |
| | |
| | Returns: |
| | array: The output array. |
| | |
| | Example: |
| | |
| | .. code-block:: python |
| | |
| | B = 2 |
| | N_q = N_kv = 32 |
| | T_q = T_kv = 1000 |
| | D = 128 |
| | |
| | q = mx.random.normal(shape=(B, N_q, T_q, D)) |
| | k = mx.random.normal(shape=(B, N_kv, T_kv, D)) |
| | v = mx.random.normal(shape=(B, N_kv, T_kv, D)) |
| | scale = D ** -0.5 |
| | out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal") |
| | )pbdoc"); |
| |
|
| | m.def( |
| | "metal_kernel", |
| | [](const std::string& name, |
| | const std::vector<std::string>& input_names, |
| | const std::vector<std::string>& output_names, |
| | const std::string& source, |
| | const std::string& header, |
| | bool ensure_row_contiguous, |
| | bool atomic_outputs) { |
| | auto kernel = mx::fast::metal_kernel( |
| | name, |
| | input_names, |
| | output_names, |
| | source, |
| | header, |
| | ensure_row_contiguous, |
| | atomic_outputs); |
| | return nb::cpp_function( |
| | PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"), |
| | nb::kw_only(), |
| | "inputs"_a, |
| | "output_shapes"_a, |
| | "output_dtypes"_a, |
| | "grid"_a, |
| | "threadgroup"_a, |
| | "template"_a = nb::none(), |
| | "init_value"_a = nb::none(), |
| | "verbose"_a = false, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), |
| | R"pbdoc( |
| | Run the kernel. |
| | |
| | Args: |
| | inputs (List[array]): The inputs passed to the Metal kernel. |
| | output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``. |
| | output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``. |
| | grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. |
| | This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``. |
| | threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. |
| | This will be passed to ``MTLComputeCommandEncoder::dispatchThreads``. |
| | template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments. |
| | These will be added as template arguments to the kernel definition. Default: ``None``. |
| | init_value (float, optional): Optional value to use to initialize all of the output arrays. |
| | By default, output arrays are uninitialized. Default: ``None``. |
| | verbose (bool, optional): Whether to print the full generated source code of the kernel |
| | when it is run. Default: ``False``. |
| | stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. |
| | |
| | Returns: |
| | List[array]: The list of output arrays.)pbdoc"); |
| | }, |
| | "name"_a, |
| | "input_names"_a, |
| | "output_names"_a, |
| | "source"_a, |
| | "header"_a = "", |
| | "ensure_row_contiguous"_a = true, |
| | "atomic_outputs"_a = false, |
| | R"pbdoc( |
| | A jit-compiled custom Metal kernel defined from a source string. |
| | |
| | Full documentation: :ref:`custom_metal_kernels`. |
| | |
| | Args: |
| | name (str): Name for the kernel. |
| | input_names (List[str]): The parameter names of the inputs in the |
| | function signature. |
| | output_names (List[str]): The parameter names of the outputs in the |
| | function signature. |
| | source (str): Source code. This is the body of a function in Metal, |
| | the function signature will be automatically generated. |
| | header (str): Header source code to include before the main function. |
| | Useful for helper functions or includes that should live outside of |
| | the main function body. |
| | ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous |
| | before the kernel runs. Default: ``True``. |
| | atomic_outputs (bool): Whether to use atomic outputs in the function signature |
| | e.g. ``device atomic<float>``. Default: ``False``. |
| | |
| | Returns: |
| | Callable ``metal_kernel``. |
| | |
| | Example: |
| | |
| | .. code-block:: python |
| | |
| | def exp_elementwise(a: mx.array): |
| | source = ''' |
| | uint elem = thread_position_in_grid.x; |
| | T tmp = inp[elem]; |
| | out[elem] = metal::exp(tmp); |
| | ''' |
| | |
| | kernel = mx.fast.metal_kernel( |
| | name="myexp", |
| | input_names=["inp"], |
| | output_names=["out"], |
| | source=source |
| | ) |
| | outputs = kernel( |
| | inputs=[a], |
| | template=[("T", mx.float32)], |
| | grid=(a.size, 1, 1), |
| | threadgroup=(256, 1, 1), |
| | output_shapes=[a.shape], |
| | output_dtypes=[a.dtype], |
| | verbose=True, |
| | ) |
| | return outputs[0] |
| | |
| | a = mx.random.normal(shape=(4, 16)).astype(mx.float16) |
| | b = exp_elementwise(a) |
| | assert mx.allclose(b, mx.exp(a)) |
| | )pbdoc"); |
| |
|
| | m.def( |
| | "cuda_kernel", |
| | [](const std::string& name, |
| | const std::vector<std::string>& input_names, |
| | const std::vector<std::string>& output_names, |
| | const std::string& source, |
| | const std::string& header, |
| | bool ensure_row_contiguous, |
| | int shared_mem) { |
| | auto kernel = mx::fast::cuda_kernel( |
| | name, |
| | input_names, |
| | output_names, |
| | source, |
| | header, |
| | ensure_row_contiguous, |
| | shared_mem); |
| | return nb::cpp_function( |
| | PyCustomKernelFunction(std::move(kernel), "[cuda_kernel]"), |
| | nb::kw_only(), |
| | "inputs"_a, |
| | "output_shapes"_a, |
| | "output_dtypes"_a, |
| | "grid"_a, |
| | "threadgroup"_a, |
| | "template"_a = nb::none(), |
| | "init_value"_a = nb::none(), |
| | "verbose"_a = false, |
| | "stream"_a = nb::none(), |
| | nb::sig( |
| | "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), |
| | R"pbdoc( |
| | Run the kernel. |
| | |
| | Args: |
| | inputs (List[array]): The inputs passed to the CUDA kernel. |
| | output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``. |
| | output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``. |
| | grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. |
| | For compatibility with :func:`metal_kernel` the grid is in threads and not in threadgroups. |
| | threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. |
| | template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments. |
| | These will be added as template arguments to the kernel definition. Default: ``None``. |
| | init_value (float, optional): Optional value to use to initialize all of the output arrays. |
| | By default, output arrays are uninitialized. Default: ``None``. |
| | verbose (bool, optional): Whether to print the full generated source code of the kernel |
| | when it is run. Default: ``False``. |
| | stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. |
| | |
| | Returns: |
| | List[array]: The list of output arrays.)pbdoc"); |
| | }, |
| | "name"_a, |
| | "input_names"_a, |
| | "output_names"_a, |
| | "source"_a, |
| | "header"_a = "", |
| | "ensure_row_contiguous"_a = true, |
| | "shared_memory"_a = 0, |
| | R"pbdoc( |
| | A jit-compiled custom CUDA kernel defined from a source string. |
| | |
| | This is the CUDA equivalent of :ref:`custom_metal_kernels`. |
| | |
| | Args: |
| | name (str): Name for the kernel. |
| | input_names (List[str]): The parameter names of the inputs in the |
| | function signature. |
| | output_names (List[str]): The parameter names of the outputs in the |
| | function signature. |
| | source (str): Source code. This is the body of a function in CUDA, |
| | the function signature will be automatically generated. |
| | header (str): Header source code to include before the main function. |
| | Useful for helper functions or includes that should live outside of |
| | the main function body. |
| | ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous |
| | before the kernel runs. Default: ``True``. |
| | shared_memory (int): The dynamic shared memory to request for the |
| | kernel. A value of 0 means no dynamic shared memory. Default: ``0``. |
| | |
| | Returns: |
| | Callable ``cuda_kernel``. |
| | |
| | Example: |
| | |
| | .. code-block:: python |
| | |
| | def exp_elementwise(a: mx.array): |
| | source = ''' |
| | auto elem = cooperative_groups::this_grid().thread_rank(); |
| | T tmp = inp[elem]; |
| | out[elem] = exp(tmp); |
| | ''' |
| | |
| | kernel = mx.fast.cuda_kernel( |
| | name="myexp", |
| | input_names=["inp"], |
| | output_names=["out"], |
| | source=source |
| | ) |
| | outputs = kernel( |
| | inputs=[a], |
| | template=[("T", mx.float32)], |
| | grid=(a.size, 1, 1), |
| | threadgroup=(256, 1, 1), |
| | output_shapes=[a.shape], |
| | output_dtypes=[a.dtype], |
| | verbose=True, |
| | ) |
| | return outputs[0] |
| | |
| | a = mx.random.normal(shape=(16, 16)).astype(mx.float16) |
| | b = exp_elementwise(a) |
| | assert mx.allclose(b, mx.exp(a)) |
| | )pbdoc"); |
| |
|
| | m.def( |
| | "precompiled_cuda_kernel", |
| | [](const std::string& name, |
| | const nb::bytes compiled_source, |
| | const std::vector<ScalarOrArray>& inputs_, |
| | const std::vector<mx::Shape>& output_shapes, |
| | const std::vector<mx::Dtype>& output_dtypes, |
| | const std::vector<nb::object>& scalars_, |
| | std::tuple<int, int, int> grid, |
| | std::tuple<int, int, int> threadgroup, |
| | int shared_memory, |
| | std::optional<float> init_value = std::nullopt, |
| | bool ensure_row_contiguous = false, |
| | mx::StreamOrDevice s = {}) { |
| | |
| | std::vector<mx::array> inputs; |
| | for (const auto& value : inputs_) { |
| | inputs.push_back(to_array(value, std::nullopt)); |
| | } |
| |
|
| | |
| | std::vector<mx::fast::ScalarArg> scalars; |
| | scalars.reserve(scalars_.size()); |
| | for (const auto& v : scalars_) { |
| | if (nb::isinstance<bool>(v)) { |
| | scalars.push_back(nb::cast<bool>(v)); |
| | } else if (nb::isinstance<int>(v)) { |
| | scalars.push_back(nb::cast<int>(v)); |
| | } else if (nb::isinstance<float>(v)) { |
| | scalars.push_back(nb::cast<float>(v)); |
| | } else { |
| | nb::object vtype = v.attr("__class__"); |
| | std::string vtype_name = |
| | nb::cast<std::string>(vtype.attr("__name__")); |
| | std::ostringstream msg; |
| | msg << "[precompiled_cuda_kernel] Invalid scalar argument type. " |
| | << "Received " << vtype_name |
| | << " but must be one of bool, int or float"; |
| | throw std::invalid_argument(msg.str()); |
| | } |
| | } |
| |
|
| | return mx::fast::precompiled_cuda_kernel( |
| | name, |
| | std::string( |
| | static_cast<const char*>(compiled_source.data()), |
| | compiled_source.size()), |
| | inputs, |
| | output_shapes, |
| | output_dtypes, |
| | scalars, |
| | grid, |
| | threadgroup, |
| | shared_memory, |
| | init_value, |
| | ensure_row_contiguous, |
| | s); |
| | }, |
| | nb::kw_only(), |
| | "name"_a, |
| | "compiled_source"_a, |
| | "inputs"_a, |
| | "output_shapes"_a, |
| | "output_dtypes"_a, |
| | "scalars"_a, |
| | "grid"_a, |
| | "threadgroup"_a, |
| | "shared_memory"_a = 0, |
| | "init_value"_a = nb::none(), |
| | "ensure_row_contiguous"_a = false, |
| | "stream"_a = nb::none(), |
| | R"pbdoc( |
| | Run a precompiled CUDA kernel defined from PTX or cubin. |
| | |
| | This op is still experimental and various parts of the API may change. |
| | |
| | Args: |
| | name (str): Name for the kernel |
| | compiled_source (bytes): The precompiled kernel in raw bytes. |
| | inputs (List[array]): The inputs passed to the CUDA kernel. |
| | output_shapes (List[Sequence[int]]): The list of shapes for each output. |
| | output_dtypes (List[Dtype]): The list of data types for each output. |
| | scalars (List[Union[bool, int, float]]): A list of scalar arguments to |
| | pass to the kernel. |
| | grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. |
| | For compatibility with :func:`metal_kernel` the grid is in threads and not in threadblocks. |
| | threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. |
| | shared_memory (int): The dynamic shared memory to request for the |
| | kernel. A value of 0 means no dynamic shared memory. Default: ``0``. |
| | init_value (float, optional): Optional value to use to initialize all of the output arrays. |
| | By default, output arrays are uninitialized. Default: ``None``. |
| | ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous |
| | before the kernel runs. Default: ``False``. |
| | stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. |
| | )pbdoc"); |
| | } |
| |
|