| | |
| | |
from typing import Optional, Tuple, Union

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
    gelu,
    gelu_approx,
    gelu_approx_grad,
    gelu_grad,
    squared_relu,
    squared_relu_grad,
)
| |
|
| | |
| |
|
| |
|
def init_to_zero(name):
    """Return a callable that zeroes the entry ``name`` of an argument mapping.

    The returned function takes a mapping, looks up ``name`` in it, calls
    ``zero_()`` on that entry and returns whatever ``zero_()`` returns.

    NOTE(review): presumably meant as a ``triton.Config`` pre-hook; it is not
    referenced anywhere in the visible part of this file.
    """

    def _zero_named_arg(nargs):
        return nargs[name].zero_()

    return _zero_named_arg
| |
|
| |
|
def get_configs_io_bound():
    """Enumerate the small-tile autotuning candidates.

    One ``triton.Config`` is produced for every combination of
    ``num_stages`` in (2, 3, 4, 5, 6), ``BLOCK_M`` in (16, 32),
    ``BLOCK_K`` in (32, 64) and ``BLOCK_N`` in (32, 64, 128, 256), always with
    ``SPLIT_K`` fixed to 1.  Tiles with ``BLOCK_N`` of at most 64 get 2 warps,
    wider tiles get 4.

    The list is ordered with ``num_stages`` varying slowest and ``BLOCK_N``
    varying fastest.
    """
    return [
        triton.Config(
            {
                "BLOCK_M": tile_m,
                "BLOCK_N": tile_n,
                "BLOCK_K": tile_k,
                "SPLIT_K": 1,
            },
            num_stages=stages,
            num_warps=4 if tile_n > 64 else 2,
        )
        for stages in (2, 3, 4, 5, 6)
        for tile_m in (16, 32)
        for tile_k in (32, 64)
        for tile_n in (32, 64, 128, 256)
    ]
| |
|
| |
|
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
        ),
        # Same tile shapes again with a larger BLOCK_K.
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
        ),
    ]
    + get_configs_io_bound(),
    # Autotuning results are cached per (CACHE_KEY_M, CACHE_KEY_N, CACHE_KEY_K).
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        # When K is a whole number of K-steps the loads along K need no mask.
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_fwd(
    # Pointers
    C,
    ACT_INPUT,
    A,
    B,
    bias,
    # Problem sizes
    M,
    N,
    K,
    # Only used as the autotune cache key
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # Row stride shared by C and ACT_INPUT; their column stride is taken to be 1
    stride_cm,
    # Strides of A viewed as (M, K) and of B viewed as (K, N)
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # Not used in the kernel body; appears in the configs and the EVEN_K heuristic
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    A_ROWMAJOR: tl.constexpr,
    B_COLMAJOR: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUT: tl.constexpr,
    ACTIVATION: tl.constexpr,
):

    """
    Kernel for computing C = activation(A x B + bias)
    - A is read as an (M, K) matrix through (stride_am, stride_ak)
    - B is read as a (K, N) matrix through (stride_bk, stride_bn)
    - bias (only read when BIAS is set) is indexed by output column
    - C has shape (M, N), row stride stride_cm, unit column stride
    - ACT_INPUT (only written when SAVE_ACT_INPUT is set) has the same layout
      as C and receives the A x B + bias values before the activation
    Each program computes one (BLOCK_M, BLOCK_N) tile of C, accumulating over
    K in float32. Both stores are masked so that tiles hanging over the edge
    of the (M, N) output never write outside it.
    """

    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # Map the 1D program id to a tile so that consecutive programs cover up to
    # GROUP_M row-tiles of the same column-tile before moving to the next one.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # rm / rn are the row / column indices of C covered by this program.
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap indices past the edge back into range so that the loads of A and B
    # along M and N need no mask.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    # The unit-stride variants drop the multiplication by the K stride.
    if A_ROWMAJOR:
        A = A + (ram[:, None] * stride_am + rk[None, :])
    else:
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    if B_COLMAJOR:
        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
    else:
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # k counts the K elements still to be consumed, so `rk < k` masks the tail.
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        if A_ROWMAJOR:
            A += BLOCK_K
        else:
            A += BLOCK_K * stride_ak
        if B_COLMAJOR:
            B += BLOCK_K
        else:
            B += BLOCK_K * stride_bk

    # Fused bias, added after the matmul. Columns past N get 0.
    if BIAS:
        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # Optionally save the pre-activation values.
    if SAVE_ACT_INPUT:
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        # The wrapped indices alias valid elements owned by other lanes or
        # programs, and the out-of-range columns of acc carry no bias, so only
        # the in-range part of the tile may be written.
        act_in_mask = (rm < M)[:, None] & (rn < N)[None, :]
        tl.store(act_in_ptrs, acc, mask=act_in_mask)

    # Optional fused activation; any other value leaves acc unchanged.
    if ACTIVATION == "gelu":
        acc = gelu(acc)
    elif ACTIVATION == "gelu_approx":
        acc = gelu_approx(acc)
    elif ACTIVATION == "squared_relu":
        acc = squared_relu(acc)
    # Recompute the output indices here rather than reusing the ones from above.
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Write back the result. The unwrapped indices can point past the edge of
    # C, so the store must be masked.
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)
| |
|
| |
|
def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = "id",
    save_act_input: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute activation(x @ weight.T + bias).
    This wrapper launches the `kernel_fwd` Triton kernel
    :param x: input tensor of shape (..., K); the leading dims are flattened
    :param weight: weight matrix of shape (N, K), same dtype as x
    :param bias: an optional bias tensor of shape (N,), same dtype as x
    :param activation: Activation name, one of "id", "gelu", "gelu_approx",
        "squared_relu"
    :param save_act_input: if True, also return x @ weight.T + bias as it was
        before the activation (for backward)
    :return: the result tensor of shape (..., N) or, when save_act_input is
        True, the tuple (result, act_input) where both have shape (..., N)
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

    # Flatten all leading dimensions into a single row dimension.
    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
    x_reshaped = x.reshape(batch_dim, n)

    # The kernel takes explicit strides, so a copy is only made when neither
    # dimension is unit-stride.
    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
        x_reshaped = x_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None

    assert (
        x.dtype == weight.dtype
    ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert (
            x.dtype == bias.dtype
        ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert (
        x_reshaped.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"

    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"

    M, K = x_reshaped.shape
    N, K = weight.shape

    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    act_input = torch.empty_like(output) if save_act_input else None

    # 1D launch: one program per (BLOCK_M, BLOCK_N) output tile.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    # With an empty output there is nothing to compute, and a zero-sized grid
    # must not be launched (or autotuned).
    if M > 0 and N > 0:
        kernel_fwd[grid](
            output,
            act_input,
            x_reshaped,
            weight,
            # The kernel only reads this pointer when BIAS is set; x is a
            # placeholder when there is no bias.
            bias if bias is not None else x,
            M,
            N,
            K,
            # Autotune cache keys: shapes bucketed by 32.
            M // 32,
            N // 32,
            K // 32,
            stride_cm=output.stride(0),
            stride_am=x_reshaped.stride(0),
            stride_ak=x_reshaped.stride(1),
            # The kernel reads B as (K, N), i.e. as weight.T.
            stride_bk=weight.stride(1),
            stride_bn=weight.stride(0),
            BIAS=bias is not None,
            SAVE_ACT_INPUT=save_act_input,
            ACTIVATION=activation,
            A_ROWMAJOR=x_reshaped.stride(1) == 1,
            B_COLMAJOR=weight.stride(1) == 1,
            GROUP_M=8,
        )

    if not save_act_input:
        return output.reshape(*batch_shape, output.shape[-1])
    else:
        return (
            output.reshape(*batch_shape, output.shape[-1]),
            act_input.reshape(*batch_shape, act_input.shape[-1]),
        )
| |
|
| |
|
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
            num_stages=stages,
            num_warps=warps,
        )
        # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
        for block_m, block_n, block_k, stages, warps in [
            (128, 256, 32, 3, 8),
            (256, 128, 32, 3, 8),
            (256, 64, 32, 4, 4),
            (64, 256, 32, 4, 4),
            (128, 128, 32, 4, 4),
            (128, 64, 32, 4, 4),
            (64, 128, 32, 4, 4),
            (128, 32, 32, 4, 4),
            (64, 32, 32, 5, 2),
            # Same tile shapes again with a larger BLOCK_K.
            (128, 256, 128, 3, 8),
            (256, 128, 128, 3, 8),
            (256, 64, 128, 4, 4),
            (64, 256, 128, 4, 4),
            (128, 128, 128, 4, 4),
            (128, 64, 64, 4, 4),
            (64, 128, 64, 4, 4),
            (128, 32, 64, 4, 4),
            (64, 32, 64, 5, 2),
        ]
    ]
    + get_configs_io_bound(),
    # Autotuning results are cached per (CACHE_KEY_M, CACHE_KEY_N, CACHE_KEY_K).
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        # When K is a whole number of K-steps the loads along K need no mask.
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_bwd(
    # Pointers
    C,
    ACT_INPUT,
    A,
    B,
    # Problem sizes
    M,
    N,
    K,
    # Only used as the autotune cache key
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # Row stride shared by C and ACT_INPUT; their column stride is taken to be 1
    stride_cm,
    # Strides of A viewed as (M, K) and of B viewed as (K, N)
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # Not used in the kernel body; appears in the configs and the EVEN_K heuristic
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):

    """
    Kernel for computing C = (A x B) * activation_grad(ACT_INPUT)
    - A is read as an (M, K) matrix through (stride_am, stride_ak)
    - B is read as a (K, N) matrix through (stride_bk, stride_bn)
    - C has shape (M, N), row stride stride_cm, unit column stride
    - ACT_INPUT is read with the same layout as C, and only when ACTIVATION
      is not "id"; with "id" the result is plain A x B
    Each program computes one (BLOCK_M, BLOCK_N) tile of C, accumulating over
    K in float32.
    """

    program = tl.program_id(axis=0)

    tiles_m = (M + BLOCK_M - 1) // BLOCK_M
    tiles_n = (N + BLOCK_N - 1) // BLOCK_N
    # Map the 1D program id to a tile so that consecutive programs cover up to
    # GROUP_M row-tiles of the same column-tile before moving to the next one.
    programs_per_group = GROUP_M * tiles_n
    group = program // programs_per_group
    rows_in_group = min(tiles_m - group * GROUP_M, GROUP_M)
    tile_m = group * GROUP_M + (program % rows_in_group)
    tile_n = (program % programs_per_group) // (rows_in_group)

    # Row / column indices of C covered by this program.
    offs_m = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap indices past the edge back into range so that the loads along M and
    # N need no mask.
    wrapped_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
    wrapped_n = tl.max_contiguous(tl.multiple_of(offs_n % N, BLOCK_N), BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A + (wrapped_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + wrapped_n[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # k_left counts the K elements still to be consumed, so `offs_k < k_left`
    # masks the tail.
    for k_left in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a_tile = tl.load(a_ptrs)
            b_tile = tl.load(b_ptrs)
        else:
            a_tile = tl.load(a_ptrs, mask=offs_k[None, :] < k_left, other=0.0)
            b_tile = tl.load(b_ptrs, mask=offs_k[:, None] < k_left, other=0.0)
        acc += tl.dot(a_tile, b_tile)

        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Optional fused activation gradient, evaluated on the saved forward
    # pre-activation values.
    if ACTIVATION != "id":
        pre_act_ptrs = ACT_INPUT + wrapped_m[:, None] * stride_cm + wrapped_n[None, :]
        pre_act = tl.load(pre_act_ptrs).to(acc.dtype)
        if ACTIVATION == "gelu":
            acc = acc * gelu_grad(pre_act)
        elif ACTIVATION == "gelu_approx":
            acc = acc * gelu_approx_grad(pre_act)
        elif ACTIVATION == "squared_relu":
            acc = acc * squared_relu_grad(pre_act)

    # Recompute the output indices here rather than reusing the ones from above.
    out_m = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    out_n = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Write back the result; the mask keeps edge tiles inside C.
    out_ptrs = C + out_m[:, None] * stride_cm + out_n[None, :]
    in_bounds = (out_m < M)[:, None] & (out_n < N)[None, :]
    tl.store(out_ptrs, acc, mask=in_bounds)
| |
|
| |
|
def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = "id",
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute (grad_output @ weight) * activation_grad(act_input).
    This wrapper launches the `kernel_bwd` Triton kernel
    :param grad_output: input tensor of shape (..., K); the leading dims are
        flattened
    :param weight: weight matrix of shape (K, N), same dtype as grad_output
    :param activation: Activation name, one of "id", "gelu", "gelu_approx",
        "squared_relu"
    :param act_input: the saved activation inputs, holding one value per
        element of the result; required unless activation is "id", in which
        case it is not read
    :return: result tensor of shape (..., N)
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

    # Flatten all leading dimensions into a single row dimension.
    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
    grad_output_reshaped = grad_output.reshape(batch_dim, n)

    # The kernel takes explicit strides, so a copy is only made when neither
    # dimension is unit-stride.
    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
        grad_output_reshaped = grad_output_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()

    assert (
        grad_output.dtype == weight.dtype
    ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert (
        grad_output_reshaped.shape[1] == weight.shape[0]
    ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != "id":
        assert act_input is not None, f"act_input is required for activation {activation}"

    # Here K is the dimension being reduced and N the width of the result.
    M, K = grad_output_reshaped.shape
    K, N = weight.shape

    if activation != "id":
        # The kernel addresses act_input with grad_input's row stride and a
        # unit column stride, so it has to be a contiguous (M, N) buffer.
        assert (
            act_input.numel() == M * N
        ), f"act_input has {act_input.numel()} elements, expected {M * N} for a result of shape ({M}, {N})"
        act_input = act_input.reshape(M, N).contiguous()

    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)

    # 1D launch: one program per (BLOCK_M, BLOCK_N) output tile.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    # With an empty result there is nothing to compute, and a zero-sized grid
    # must not be launched (or autotuned).
    if M > 0 and N > 0:
        kernel_bwd[grid](
            grad_input,
            act_input,
            grad_output_reshaped,
            weight,
            M,
            N,
            K,
            # Autotune cache keys: shapes bucketed by 32.
            M // 32,
            N // 32,
            K // 32,
            stride_cm=grad_input.stride(0),
            stride_am=grad_output_reshaped.stride(0),
            stride_ak=grad_output_reshaped.stride(1),
            # weight is already laid out as (K, N) here, so no transpose.
            stride_bk=weight.stride(0),
            stride_bn=weight.stride(1),
            ACTIVATION=activation,
            GROUP_M=8,
        )

    return grad_input.reshape(*batch_shape, grad_input.shape[-1])
| |
|