| | import torch |
| | import triton |
| | import triton.language as tl |
| | from typing import Optional |
| |
|
| | BLOCK_M = 128 |
| | ALLOW_TF32 = True |
| |
|
| |
|
| |
|
# Inner K-loop for one expert: accumulates X[M_in_idx, :] @ W[E_idx] into `acc`,
# one (BLOCK_M x BLOCK_K) x (BLOCK_K x BLOCK_N) tile product per iteration.
@triton.jit
def _compute_expert_block(
    E_idx, E_mask,                      # scalar expert id; (BLOCK_M,) mask of rows routed to this expert
    M_in_idx,                           # (BLOCK_M,) row indices into X
    N_block, N_mask,                    # (BLOCK_N,) output-column indices and their bounds mask
    X_ptr, stride_xm, stride_xk,
    W_ptr, stride_we, stride_wk, stride_wn,
    K,                                  # reduction length
    acc,                                # (BLOCK_M, BLOCK_N) running accumulator, returned updated
    no_k_mask,                          # true when K % BLOCK_K == 0, so K-bounds masking can be skipped
    BLOCK_K,
    allow_tf32=True,
):
    K_block = tl.arange(0, BLOCK_K)
    X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
    # W is indexed as W[e, k, n]; offsetting by E_idx * stride_we selects this expert's weights.
    W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
    iters = tl.cdiv(K, BLOCK_K)
    for K_block_id in range(iters):
        if no_k_mask:
            # K divides evenly into BLOCK_K tiles: only row/column masks are needed.
            x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
            w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
        else:
            K_mask = (K_block_id * BLOCK_K + K_block) < K
            x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
            w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
        # Advance both tiles one BLOCK_K step along the reduction dimension.
        X_blk_ptrs += BLOCK_K * stride_xk
        W_blk_ptrs += BLOCK_K * stride_wk
        acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
    return acc
| |
|
| |
|
def _scatter2scatter_configs():
    """Return the autotune configuration candidates for `_scatter2scatter`."""
    config = triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4)
    return [config]
| |
|
# Fused scatter->matmul->scatter kernel: for each scattered row, multiplies the
# corresponding input row of X by its routed expert's weight matrix W[e] and
# writes the result to Y, optionally adding a per-expert bias B.
@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _scatter2scatter(
    X_ptr, stride_xm: tl.constexpr, stride_xk: tl.constexpr,
    W_ptr, stride_we, stride_wk: tl.constexpr, stride_wn: tl.constexpr,
    Y_ptr, stride_ym: tl.constexpr, stride_yn: tl.constexpr,
    B_ptr, stride_be: tl.constexpr, stride_bn: tl.constexpr,
    grouped_idx_ptr, expert_idxs_ptr,
    FAN_OUT: tl.constexpr,              # top-k replication factor: scattered rows per input row
    M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
    allow_tf32: tl.constexpr,
    x_grouped: tl.constexpr, y_grouped: tl.constexpr,
    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
    pid = tl.program_id(axis=0)
    # 1-D launch: decompose pid into a (row-block, column-block) pair.
    N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
    M_block_id = pid // N_BLOCK_COUNT
    N_block_id = pid % N_BLOCK_COUNT
    M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    N_mask = N_block < N
    M_boundary_mask = M_block < (FAN_OUT * M)
    # Rows past the end get sentinel expert id E, which never matches a real expert.
    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
    # NOTE(review): recomputed locally even though the NO_K_MASK heuristic arg
    # carries the same value — consider passing NO_K_MASK through instead.
    no_k_mask = K % BLOCK_K == 0
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # expert_idxs_ptr is presumably sorted (see `sorted_expert_idxs` in the wrapper),
    # so this row block only spans the contiguous expert range [E_first_idx, E_last_idx].
    E_first_idx = tl.min(E_idxs)
    E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
    M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
    for E_idx in range(E_first_idx, E_last_idx + 1):
        E_mask = E_idxs == E_idx
        E_M_idx = M_idx
        if x_grouped:
            # X is already in scattered (grouped) order: read rows directly.
            M_in_idx = M_block
        else:
            # X is in original order: map the scattered position back to its source row.
            M_in_idx = E_M_idx // FAN_OUT
        acc = _compute_expert_block(
            E_idx, E_mask,
            M_in_idx, N_block, N_mask,
            X_ptr, stride_xm, stride_xk,
            W_ptr, stride_we, stride_wk, stride_wn,
            K,
            acc,
            no_k_mask,
            BLOCK_K,
            allow_tf32=allow_tf32,
        )
    if B_ptr is not None:
        # Per-row expert bias; out-of-range rows are masked out by M_boundary_mask.
        B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
        acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
    if y_grouped:
        M_out_idx = M_block
    else:
        M_out_idx = M_idx
    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
    tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
| |
|
def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
                    b=None,
                    x_grouped=False, y_grouped=False,
                    out=None):
    """Dispatch the fused scatter-matmul-scatter kernel.

    Validates the index/input shapes, allocates (or reuses) the output buffer,
    and delegates the actual launch to ``scatter2scatter_compileable``.
    Returns the output tensor of shape ``(len(sorted_expert_idxs), W.size(-1))``.
    """
    total_rows = sorted_expert_idxs.size(0)
    assert sorted_scattered_idxs.size(0) == total_rows
    assert total_rows == X.size(0) * k

    out_features = W.size(-1)
    if out is not None:
        # Caller-provided buffer must already have the right shape.
        assert out.size(0) == total_rows and out.size(1) == out_features
        result = out
    else:
        result = torch.empty((total_rows, out_features), device=X.device, dtype=X.dtype)

    scatter2scatter_compileable(result, W, X, k, sorted_expert_idxs,
                                sorted_scattered_idxs, b, x_grouped, y_grouped)
    return result
| |
|
| |
|
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
def scatter2scatter_compileable(
        output: torch.Tensor,
        W: torch.Tensor,
        X: torch.Tensor,
        k: int,
        sorted_expert_idxs: torch.Tensor,
        sorted_scattered_idxs: torch.Tensor,
        b: Optional[torch.Tensor],
        x_grouped: bool, y_grouped: bool) -> None:
    """Launch the ``_scatter2scatter`` Triton kernel, writing into ``output``.

    Registered as a torch custom op so the launch is opaque to torch.compile;
    ``output`` is the only argument mutated.

    Args:
        output: (L, N) destination buffer, L == sorted_expert_idxs.size(0).
        W: (E, K, N) expert weights.
        X: (M, K) input rows.
        k: fan-out (top-k) — scattered rows per input row.
        sorted_expert_idxs / sorted_scattered_idxs: (L,) routing indices.
        b: optional (E, N) per-expert bias.
        x_grouped / y_grouped: whether X / output are already in grouped order.
    """
    def grid(META):
        # One program per (row-block, column-block) pair, flattened to 1-D.
        return (
            triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"])
            * triton.cdiv(META['N'], META['BLOCK_N']),
        )

    # A missing bias is passed to the kernel as a null pointer with zero strides.
    # (Fixed: dropped a dead `b = None` reassignment and renamed the second
    # stride to match the kernel's `stride_bn` parameter.)
    if b is None:
        stride_be = stride_bn = 0
    else:
        stride_be, stride_bn = b.stride()

    _scatter2scatter[grid](
        X, X.stride(0), X.stride(1),
        W, W.stride(0), W.stride(1), W.stride(2),
        output, output.stride(0), output.stride(1),
        b, stride_be, stride_bn,
        grouped_idx_ptr=sorted_scattered_idxs,
        expert_idxs_ptr=sorted_expert_idxs,
        FAN_OUT=k,
        M=X.size(0),
        K=X.size(1),
        N=output.size(1), E=W.size(0),
        BLOCK_M=BLOCK_M,
        ACC_TYPE=tl.float32,
        allow_tf32=ALLOW_TF32,
        x_grouped=x_grouped, y_grouped=y_grouped,
    )
| |
|
| |
|
def _config_XtY():
    """Return the autotune configuration candidates for `_groupXtY`."""
    config = triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4)
    return [config]
| |
|
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
    """Compute per-expert weight (and optional bias) gradients.

    Allocates the gradient buffers and delegates the reduction to
    ``groupXtY_compileable``. Returns ``(DW, Db)`` where DW is (E, K, N)
    and Db is (E, N) or None when ``has_bias`` is False.
    """
    # Allocate in (E, N, K) order, then expose the permuted (E, K, N) view.
    grad_w_t = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
    DW = grad_w_t.permute(0, 2, 1)
    Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype) if has_bias else None
    groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
    return DW, Db
| |
|
| |
|
# Fixed: the kernel also stores into Db (bias gradient), so it must be declared
# in mutates_args — otherwise torch.compile may assume Db is unchanged.
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
def groupXtY_compileable(
        E: int,
        DW: torch.Tensor,
        Db: Optional[torch.Tensor],
        DY: torch.Tensor,
        X: torch.Tensor,
        expert_offsets: torch.Tensor) -> None:
    """Launch ``_groupXtY`` to accumulate per-expert X^T @ DY into ``DW``
    and, when ``Db`` is given, per-expert column sums of DY into ``Db``.

    Args:
        E: number of experts.
        DW: (E, K, N) weight-gradient buffer, mutated in place.
        Db: optional (E, N) bias-gradient buffer, mutated in place.
        DY: (M, N) upstream gradients, rows grouped by expert.
        X: (M, K) inputs, rows grouped by expert.
        expert_offsets: (E,) cumulative per-expert row counts.
    """
    def grid(META):
        # Axis 0 enumerates (expert, K-block) pairs; axis 1 enumerates N-blocks.
        return (
            E * triton.cdiv(META['K'], META['BLOCK_K']),
            triton.cdiv(META['N'], META['BLOCK_N']),
        )

    # A missing bias buffer is passed as a null pointer with zero strides.
    if Db is None:
        stride_dbe = stride_dbn = 0
    else:
        stride_dbe, stride_dbn = Db.stride()

    _groupXtY[grid](
        DY, DY.stride(0), DY.stride(1),
        X, X.stride(0), X.stride(1),
        DW, DW.stride(0), DW.stride(1), DW.stride(2),
        Db, stride_dbe, stride_dbn,
        expert_offsets,
        M=DY.size(0), N=DY.size(-1), K=X.size(-1),
        ACC_TYPE=tl.float32,
        allow_tf32=ALLOW_TF32
    )
| |
|
| |
|
# Backward-weight kernel: each program computes one (BLOCK_K x BLOCK_N) tile of
# DW[e] = X[rows_of_e]^T @ DY[rows_of_e] for one expert e.
@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _groupXtY(
    DY_ptr, stride_dym, stride_dyk,
    X_ptr, stride_xm, stride_xn,
    DW_ptr, stride_dwe, stride_dwk, stride_dwn,
    Db_ptr, stride_dbe, stride_dbn,
    expert_offsets_ptr,
    M, K: tl.constexpr, N: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
    allow_tf32: tl.constexpr,
    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)
    num0 = tl.num_programs(0)
    num1 = tl.num_programs(1)
    # Remap program ids in groups of 4 — presumably for cache locality; see tl.swizzle2d.
    pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
    # Axis 0 enumerates (expert, K-block) pairs; axis 1 enumerates N-blocks.
    K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
    E_idx = pid0 // K_BLOCK_COUNT
    K_block_id = pid0 % K_BLOCK_COUNT
    N_block_id = pid1
    # expert_offsets holds cumulative row counts: [start_idx, end_idx) is the
    # contiguous row range belonging to expert E_idx.
    if E_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
    # Experts with no rows write nothing; their DW slice keeps the caller's zeros.
    if end_idx > start_idx:
        M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
        K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
        K_mask = K_block < K
        # Wrap then re-annotate contiguity so the compiler can vectorize the loads.
        K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
        N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
        N_mask = N_block < N
        N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
        M_idxs = M_block
        # X is read transposed (K down the rows) so tl.dot yields X^T @ DY directly.
        xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
        dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
        # Only the K_block_id == 0 programs also reduce the bias gradient, so each
        # (expert, N-block) slice of Db is written exactly once.
        if (Db_ptr is not None) and (K_block_id == 0):
            _xty_and_bias(
                E_idx, start_idx, end_idx,
                M_block,
                K_block, K_mask, N_block, N_mask,
                dy_blk_ptrs, stride_dym,
                xt_blk_ptrs, stride_xm,
                DW_ptr, stride_dwe, stride_dwk, stride_dwn,
                Db_ptr, stride_dbe, stride_dbn,
                BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
                allow_tf32, NO_K_MASK, NO_N_MASK,
                compute_bias=True
            )
        else:
            _xty_and_bias(
                E_idx, start_idx, end_idx,
                M_block,
                K_block, K_mask, N_block, N_mask,
                dy_blk_ptrs, stride_dym,
                xt_blk_ptrs, stride_xm,
                DW_ptr, stride_dwe, stride_dwk, stride_dwn,
                Db_ptr, stride_dbe, stride_dbn,
                BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
                allow_tf32, NO_K_MASK, NO_N_MASK,
                compute_bias=False
            )
| |
|
| |
|
# Accumulates one (BLOCK_K x BLOCK_N) tile of X^T @ DY over an expert's row
# range [start_idx, end_idx), stores it to DW[E_idx], and (when compute_bias
# is set) also stores the column sums of DY to Db[E_idx].
@triton.jit
def _xty_and_bias(
    E_idx, start_idx, end_idx,
    M_block,
    K_block, K_mask, N_block, N_mask,
    dy_blk_ptrs, stride_dym,
    xt_blk_ptrs, stride_xm,
    DW_ptr, stride_dwe, stride_dwk, stride_dwn,
    Db_ptr, stride_dbe, stride_dbn,
    BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    allow_tf32, NO_K_MASK, NO_N_MASK,
    compute_bias: tl.constexpr          # constexpr so the bias path is compiled out when unused
):
    # Bias-gradient accumulator: column sums of DY over this expert's rows.
    if compute_bias:
        db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
    else:
        db_acc = None
    acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
    iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
    for i in range(0, iters):
        # Mask off rows past the expert's last row (only bites on the final iteration).
        M_mask = (i * BLOCK_M + M_block) < end_idx
        if NO_K_MASK:
            xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
        else:
            xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
        if NO_N_MASK:
            dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
        else:
            dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
        acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
        # Advance both tiles one BLOCK_M step along the row dimension.
        xt_blk_ptrs += BLOCK_M * stride_xm
        dy_blk_ptrs += BLOCK_M * stride_dym
        if compute_bias:
            # Masked-out rows load as zero, so they do not perturb the sum.
            db_acc += tl.sum(dy, axis=0)
    DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
    # Cast the accumulator down to the output element type before storing.
    acc = acc.to(DW_blk_ptrs.dtype.element_ty)
    tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
    if compute_bias:
        Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
        tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
| |
|
| |
|
| |
|
def _config_grouping():
    """Return the autotune configuration candidates for `_group`."""
    config = triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4)
    return [config]
| |
|
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
    """Gather rows of ``A`` into expert-grouped order.

    Produces a ``(len(sorted_expert_idxs), A.size(1))`` tensor (optionally
    scaled per row by ``coeff``) by delegating to ``group_compileable``.
    ``out``, when given, is used as the destination buffer.
    """
    num_rows = sorted_expert_idxs.size(0)
    num_cols = A.size(1)
    assert A.size(0) * fan_out == num_rows
    if out is None:
        Y = torch.empty((num_rows, num_cols), dtype=A.dtype, device=A.device)
    else:
        Y = out
    group_compileable(A, num_cols, num_rows, Y, coeff, coeff is not None,
                      fan_out, sorted_expert_idxs)
    return Y
| |
|
| |
|
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
def group_compileable(
        A: torch.Tensor,
        K: int,
        N: int,
        Y: torch.Tensor,
        # Fixed: `group` passes coeff=None by default, so the annotation must be
        # Optional for the custom-op schema (matches `b` in scatter2scatter_compileable).
        coeff: Optional[torch.Tensor], has_coeff: bool,
        fan_out: int,
        sorted_expert_idxs: torch.Tensor) -> None:
    """Launch the ``_group`` Triton kernel, writing grouped rows into ``Y``.

    Args:
        A: (M, K) source rows.
        K: number of columns.
        N: number of grouped rows (M * fan_out).
        Y: (N, K) destination buffer, mutated in place.
        coeff: optional (N,) per-row scaling factors; used only when has_coeff.
        has_coeff: whether to apply ``coeff``.
        fan_out: scattered rows per source row.
        sorted_expert_idxs: (N,) scattered row indices.
    """
    def grid(META):
        # One program per block of grouped rows.
        return (triton.cdiv(META['N'], META['BLOCK_N']),)

    _group[grid](
        A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out,
        Y, Y.stride(0), Y.stride(1),
        sorted_expert_idxs,
        N, K
    )
| |
|
| |
|
# Row-gather kernel: tgt[n, :] = src[grouped_idx[n] // FAN_OUT, :], optionally
# scaled by a per-row coefficient, copied in BLOCK_K-wide column tiles.
@triton.autotune(configs=_config_grouping(), key=['K'])
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
})
@triton.jit
def _group(
    src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
    tgt_ptr, stride_tn, stride_ti,
    grouped_idx_ptr,
    N, K: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NO_K_MASK: tl.constexpr
):
    pid = tl.program_id(axis=0)
    N_block_id = pid
    N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
    N_mask = N_blk < N
    # Wrap then re-annotate contiguity so the compiler can vectorize the accesses.
    N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
    N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
    K_blk = tl.arange(0, BLOCK_K)
    # Each grouped row n reads source row N_idx[n] // FAN_OUT.
    src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
    tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
    if has_coeff:
        # Per-row scaling coefficient, broadcast across the K dimension.
        c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
    iters = tl.cdiv(K, BLOCK_K)
    for i in range(0, iters):
        # All but the last column tile are full, so only the final iteration
        # (when K is not a multiple of BLOCK_K) needs a K mask.
        if NO_K_MASK or i < iters - 1:
            block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
            if has_coeff:
                block *= c
            tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
        else:
            K_mask = (i * BLOCK_K + K_blk) < K
            mask = N_mask[:, None] & K_mask[None, :]
            block = tl.load(src_blk_ptrs, mask=mask)
            if has_coeff:
                block *= c
            tl.store(tgt_blk_ptrs, block, mask=mask)
        # Advance both tiles one BLOCK_K step along the columns.
        src_blk_ptrs += BLOCK_K * stride_sk
        tgt_blk_ptrs += BLOCK_K * stride_ti
| |
|