.. _custom_metal_kernels:

Custom Metal Kernels
====================

MLX supports writing custom Metal kernels through the Python and C++ APIs.

Simple Example
--------------

.. currentmodule:: mlx.core

Let's write a custom kernel that computes ``exp`` elementwise:

.. code-block:: python

    import mlx.core as mx

    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """

    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )

    def exp_elementwise(a: mx.array):
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        return outputs[0]

    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))

Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.

.. note::
   Only pass the body of the Metal kernel in ``source``. The function
   signature is generated automatically.

The full function signature will be generated using:

* The shapes/dtypes of ``inputs``
    In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
    so we will add ``const device float16_t* inp`` to the signature.
    ``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
    in ``source``.
* The list of ``output_dtypes``
    In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
    so we add ``device float16_t* out``.
* Template parameters passed using ``template``
    In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
    and instantiates the template with ``custom_kernel_myexp_float<float>``.
    Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
    These will be added as function arguments.
    All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.

Putting this all together, the generated function signature for ``myexp`` is as follows:

.. code-block:: cpp

    template <typename T>
    [[kernel]] void custom_kernel_myexp_float(
      const device float16_t* inp [[buffer(0)]],
      device float16_t* out [[buffer(1)]],
      uint3 thread_position_in_grid [[thread_position_in_grid]]) {

            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = metal::exp(tmp);

    }

    template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each threadgroup
dimension should be less than or equal to the corresponding grid dimension.

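As a rough illustration of the dispatch arithmetic described for ``grid`` and
``threadgroup``, here is a minimal pure-Python sketch. The helpers
``clamp_threadgroup`` and ``num_threadgroups`` are hypothetical names used only
for illustration and are not part of MLX. The sketch assumes that partial
threadgroups at the edge of the grid are rounded up in each dimension:

```python
import math

def clamp_threadgroup(grid, threadgroup):
    """Shrink each threadgroup dimension so it never exceeds the grid."""
    return tuple(min(g, t) for g, t in zip(grid, threadgroup))

def num_threadgroups(grid, threadgroup):
    """Count threadgroups, rounding partial groups up in each dimension."""
    return math.prod(-(-g // t) for g, t in zip(grid, threadgroup))

# Hypothetical launch: 64 threads, e.g. one per element of a (4, 16) array.
grid = (64, 1, 1)
threadgroup = clamp_threadgroup(grid, (256, 1, 1))
print(math.prod(grid), threadgroup, num_threadgroups(grid, threadgroup))
```

Whether clamping like this is needed in practice depends on how the launch is
configured, so treat it as a way to reason about sizes rather than a required
step.
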
Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.

Using Shape/Strides
-------------------

:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.

If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built-in indexing utils to fetch
the right elements for each thread.

Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:

.. code-block:: python

    source = """
        uint elem = thread_position_in_grid.x;
        // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
        uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
        T tmp = inp[loc];
        // Output arrays are always row contiguous
        out[elem] = metal::exp(tmp);
    """

    kernel = mx.fast.metal_kernel(
        name="myexp_strided",
        input_names=["inp"],
        output_names=["out"],
        source=source,
        ensure_row_contiguous=False,
    )

    def exp_elementwise(a: mx.array):
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        return outputs[0]

    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    # Keep every other row so the input is no longer contiguous
    a = a[::2]
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))

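To make the index arithmetic concrete, the following pure-Python sketch shows
the kind of mapping an ``elem_to_loc``-style helper performs: it turns a
row-major element index into a memory offset using the shape and strides. This
is an illustrative reimplementation, not the MLX source. It assumes strides are
counted in elements, and the ``(2, 16)`` shape with strides ``(32, 1)`` is an
assumed layout for a ``(4, 16)`` array sliced with ``[::2]``:

```python
def elem_to_loc(elem, shape, strides):
    """Map a row-major element index to a memory offset (in elements)."""
    loc = 0
    # Peel off one coordinate per dimension, starting from the innermost.
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (elem % dim) * stride
        elem //= dim
    return loc

# Assumed layout of a (4, 16) array after keeping every other row.
shape, strides = (2, 16), (32, 1)
for elem in (0, 15, 16, 17):
    print(elem, elem_to_loc(elem, shape, strides))
```

For a row-contiguous array the mapping is the identity, which is why the
simpler kernel can index ``inp[elem]`` directly.
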
Complex Example
---------------

Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.

We'll start with the following MLX implementation using standard ops:

.. code-block:: python

    def grid_sample_ref(x, grid):
        N, H_in, W_in, _ = x.shape
        ix = ((grid[..., 0] + 1) * W_in - 1) / 2
        iy = ((grid[..., 1] + 1) * H_in - 1) / 2

        ix_nw = mx.floor(ix).astype(mx.int32)
        iy_nw = mx.floor(iy).astype(mx.int32)

        ix_ne = ix_nw + 1
        iy_ne = iy_nw

        ix_sw = ix_nw
        iy_sw = iy_nw + 1

        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

        nw = (ix_se - ix) * (iy_se - iy)
        ne = (ix - ix_sw) * (iy_sw - iy)
        sw = (ix_ne - ix) * (iy - iy_ne)
        se = (ix - ix_nw) * (iy - iy_nw)

        I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
        I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
        I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
        I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

        mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
        mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
        mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
        mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

        I_nw *= mask_nw[..., None]
        I_ne *= mask_ne[..., None]
        I_sw *= mask_sw[..., None]
        I_se *= mask_se[..., None]

        output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

        return output

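The coordinate mapping and corner weights in ``grid_sample_ref`` can be checked
by hand for a single sample point. The sketch below is a scalar, pure-Python
restatement of those two steps; the helper names ``unnormalize`` and
``bilinear_weights`` are hypothetical and chosen only for this illustration:

```python
import math

def unnormalize(g, size):
    """Map a normalized coordinate to pixel space, as in grid_sample_ref."""
    return ((g + 1) * size - 1) / 2

def bilinear_weights(ix, iy):
    """Return the (nw, ne, sw, se) weights for a point in pixel space."""
    ix_nw, iy_nw = math.floor(ix), math.floor(iy)
    ix_se, iy_se = ix_nw + 1, iy_nw + 1
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_nw) * (iy_se - iy)
    sw = (ix_se - ix) * (iy - iy_nw)
    se = (ix - ix_nw) * (iy - iy_nw)
    return nw, ne, sw, se

# A point slightly right of and below the image center of a 4x4 input.
ix, iy = unnormalize(0.1, 4), unnormalize(0.2, 4)
weights = bilinear_weights(ix, iy)
print(ix, iy, weights, sum(weights))
```

The four weights always sum to one, so the masking step is what reduces the
output near the border, where some corners fall outside the image.
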
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

    import numpy as np

    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
        int W = x_shape[2];
        int C = x_shape[3];
        int gH = grid_shape[1];
        int gW = grid_shape[2];

        int w_stride = C;
        int h_stride = W * w_stride;
        int b_stride = H * h_stride;

        uint grid_idx = elem / C * 2;
        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

        int ix_nw = floor(ix);
        int iy_nw = floor(iy);

        int ix_ne = ix_nw + 1;
        int iy_ne = iy_nw;

        int ix_sw = ix_nw;
        int iy_sw = iy_nw + 1;

        int ix_se = ix_nw + 1;
        int iy_se = iy_nw + 1;

        T nw = (ix_se - ix) * (iy_se - iy);
        T ne = (ix - ix_sw) * (iy_sw - iy);
        T sw = (ix_ne - ix) * (iy - iy_ne);
        T se = (ix - ix_nw) * (iy - iy_nw);

        int batch_idx = elem / C / gH / gW * b_stride;
        int channel_idx = elem % C;
        int base_idx = batch_idx + channel_idx;

        T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
        T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
        T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
        T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

        I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
        I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
        I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
        I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

        out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
    """

    kernel = mx.fast.metal_kernel(
        name="grid_sample",
        input_names=["x", "grid"],
        output_names=["out"],
        source=source,
    )

    @mx.custom_function
    def grid_sample(x, grid):

        assert x.ndim == 4, "`x` must be 4D."
        assert grid.ndim == 4, "`grid` must be 4D."

        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape
        out_shape = (B, gN, gM, C)

        assert D == 2, "Last dim of `grid` must be size 2."

        outputs = kernel(
            inputs=[x, grid],
            template=[("T", x.dtype)],
            output_shapes=[out_shape],
            output_dtypes=[x.dtype],
            grid=(np.prod(out_shape), 1, 1),
            threadgroup=(256, 1, 1),
        )
        return outputs[0]

For a reasonably sized input such as:

.. code-block:: python

    x.shape = (8, 1024, 1024, 64)
    grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:

``55.7ms -> 6.7ms => 8x speed up``

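Timings like these depend on the hardware and on how the measurement is taken.
The sketch below is one generic way to time a function in pure Python. It is
not necessarily how the numbers above were obtained, and the ``benchmark``
helper is a hypothetical name. Because MLX evaluates lazily, a function timed
this way would need to force evaluation itself, for example by calling
``mx.eval`` on its result:

```python
import time

def benchmark(fn, *args, warmup=3, iters=10):
    """Return the mean wall-clock time of fn(*args) in milliseconds."""
    # Warm-up runs absorb one-time costs such as kernel compilation.
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters * 1e3

# Stand-in workload; with MLX, fn should evaluate its outputs before returning.
elapsed_ms = benchmark(lambda n: sum(range(n)), 100_000)
print(f"{elapsed_ms:.3f} ms per call")
```
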
Grid Sample VJP
---------------

Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features:

* ``init_value=0``
    Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.

* ``atomic_outputs=True``
    Designate all of the kernel outputs as ``atomic`` in the function signature.
    This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
    See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.

We can then implement the backwards pass as follows:

.. code-block:: python

    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
        int W = x_shape[2];
        int C = x_shape[3];
        // Pad C to the nearest larger simdgroup size multiple
        int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

        int gH = grid_shape[1];
        int gW = grid_shape[2];

        int w_stride = C;
        int h_stride = W * w_stride;
        int b_stride = H * h_stride;

        uint grid_idx = elem / C_padded * 2;
        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

        int ix_nw = floor(ix);
        int iy_nw = floor(iy);

        int ix_ne = ix_nw + 1;
        int iy_ne = iy_nw;

        int ix_sw = ix_nw;
        int iy_sw = iy_nw + 1;

        int ix_se = ix_nw + 1;
        int iy_se = iy_nw + 1;

        T nw = (ix_se - ix) * (iy_se - iy);
        T ne = (ix - ix_sw) * (iy_sw - iy);
        T sw = (ix_ne - ix) * (iy - iy_ne);
        T se = (ix - ix_nw) * (iy - iy_nw);

        int batch_idx = elem / C_padded / gH / gW * b_stride;
        int channel_idx = elem % C_padded;
        int base_idx = batch_idx + channel_idx;

        T gix = T(0);
        T giy = T(0);
        if (channel_idx < C) {
            int cot_index = elem / C_padded * C + channel_idx;
            T cot = cotangent[cot_index];
            if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                T I_nw = x[offset];
                gix -= I_nw * (iy_se - iy) * cot;
                giy -= I_nw * (ix_se - ix) * cot;
            }
            if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                T I_ne = x[offset];
                gix += I_ne * (iy_sw - iy) * cot;
                giy -= I_ne * (ix - ix_sw) * cot;
            }
            if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                T I_sw = x[offset];
                gix -= I_sw * (iy - iy_ne) * cot;
                giy += I_sw * (ix_ne - ix) * cot;
            }
            if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                T I_se = x[offset];
                gix += I_se * (iy - iy_nw) * cot;
                giy += I_se * (ix - ix_nw) * cot;
            }
        }

        T gix_mult = W / 2;
        T giy_mult = H / 2;

        // Reduce across each simdgroup first.
        // This is much faster than relying purely on atomics.
        gix = simd_sum(gix);
        giy = simd_sum(giy);

        if (thread_index_in_simdgroup == 0) {
            atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
            atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
        }
    """
    kernel = mx.fast.metal_kernel(
        name="grid_sample_grad",
        input_names=["x", "grid", "cotangent"],
        output_names=["x_grad", "grid_grad"],
        source=source,
        atomic_outputs=True,
    )

    @grid_sample.vjp
    def grid_sample_vjp(primals, cotangent, _):
        x, grid = primals
        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape

        assert D == 2, "Last dim of `grid` must be size 2."

        # Pad the channel dimension to a multiple of the simdgroup size
        # so that each `simd_sum` stays within a single grid location.
        simdgroup_size = 32
        C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
        grid_size = B * gN * gM * C_padded
        outputs = kernel(
            inputs=[x, grid, cotangent],
            template=[("T", x.dtype)],
            output_shapes=[x.shape, grid.shape],
            output_dtypes=[x.dtype, x.dtype],
            grid=(grid_size, 1, 1),
            threadgroup=(256, 1, 1),
            init_value=0,
        )
        return outputs[0], outputs[1]

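The channel padding used in the backward pass is easy to check in isolation.
The following pure-Python sketch mirrors the ``C_padded`` arithmetic and the
``channel_idx < C`` guard from the kernel. The helper names
``pad_to_multiple`` and ``thread_layout`` are hypothetical, and the example
assumes a simdgroup size of 32 as in the code above:

```python
def pad_to_multiple(n, multiple):
    """Round n up to the nearest multiple (same formula as C_padded)."""
    return (n + multiple - 1) // multiple * multiple

def thread_layout(elem, C, C_padded):
    """Return (grid location, channel index, whether the thread does work)."""
    location = elem // C_padded
    channel_idx = elem % C_padded
    return location, channel_idx, channel_idx < C

# Hypothetical channel count that is not a multiple of the simdgroup size.
C, simdgroup_size = 40, 32
C_padded = pad_to_multiple(C, simdgroup_size)
for elem in (0, 39, 40, 63, 64):
    print(elem, thread_layout(elem, C, C_padded))
```

Because ``C_padded`` is a multiple of the simdgroup size, each run of 32
consecutive thread indices maps to a single grid location, and the padded
threads contribute zero to the reduction.
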
There's an even larger speed up for the vjp:

``676.4ms -> 16.7ms => 40x speed up``