|
|
import torch |
|
|
import triton |
|
|
import triton.language as tl |
|
|
from typing import Optional |
|
|
|
|
|
# Row-tile size handed to the `_scatter2scatter` launch as its BLOCK_M constexpr
# (it is also what the launch grid divides the row count by).
BLOCK_M = 128
# Forwarded as `allow_tf32` to the Triton kernels launched from this module,
# which in turn pass it to their `tl.dot` calls.
ALLOW_TF32 = True
|
|
|
|
|
|
|
|
|
|
|
@triton.jit
def _compute_expert_block(
        E_idx, E_mask,
        M_in_idx,
        N_block, N_mask,
        X_ptr, stride_xm, stride_xk,
        W_ptr, stride_we, stride_wk, stride_wn,
        K,
        acc,
        no_k_mask,
        BLOCK_K,
        allow_tf32=True,
):
    """Accumulate one expert's contribution to an output tile and return it.

    Walks the K dimension in steps of BLOCK_K. At each step it loads a tile of
    X (rows `M_in_idx`, restricted to rows where `E_mask` is true) and a tile of
    the weights of expert `E_idx` (columns `N_block`, restricted by `N_mask`),
    and adds their product into `acc`.

    Args:
        E_idx: expert whose weight slice (offset `E_idx * stride_we`) is used.
        E_mask: per-row mask; rows where it is false contribute nothing.
        M_in_idx: per-row index into X.
        N_block, N_mask: output column offsets and their validity mask.
        X_ptr, stride_xm, stride_xk: input base pointer and strides.
        W_ptr, stride_we, stride_wk, stride_wn: weight base pointer and strides.
        K: length of the reduction dimension.
        acc: running accumulator passed to `tl.dot`.
        no_k_mask: when true, the K-boundary mask is skipped (caller guarantees
            K is a multiple of BLOCK_K).
        BLOCK_K: reduction tile size.
        allow_tf32: forwarded to `tl.dot`.

    Returns:
        The updated accumulator.
    """
    K_block = tl.arange(0, BLOCK_K)
    X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
    W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
    iters = tl.cdiv(K, BLOCK_K)

    for K_block_id in range(iters):
        # `other=0.0` makes every masked-out lane an explicit zero: rows that
        # belong to a different expert (and out-of-range K / N lanes) must add
        # exactly nothing to the dot product below.
        if no_k_mask:
            x = tl.load(X_blk_ptrs, mask=E_mask[:, None], other=0.0)
            w = tl.load(W_blk_ptrs, mask=N_mask[None, :], other=0.0)
        else:
            K_mask = (K_block_id * BLOCK_K + K_block) < K
            x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :], other=0.0)
            w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)

        # Advance both tiles to the next K step.
        X_blk_ptrs += BLOCK_K * stride_xk
        W_blk_ptrs += BLOCK_K * stride_wk
        acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
    return acc
|
|
|
|
|
|
|
|
def _scatter2scatter_configs():
    """Return the autotune candidate list used by the `_scatter2scatter` kernel."""
    tile_sizes = {'BLOCK_N': 128, 'BLOCK_K': 32}
    only_candidate = triton.Config(tile_sizes, num_stages=4, num_warps=4)
    return [only_candidate]
|
|
|
|
|
@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _scatter2scatter(
    X_ptr, stride_xm: tl.constexpr, stride_xk: tl.constexpr,
    W_ptr, stride_we, stride_wk: tl.constexpr, stride_wn: tl.constexpr,
    Y_ptr, stride_ym: tl.constexpr, stride_yn: tl.constexpr,
    B_ptr, stride_be: tl.constexpr, stride_bn: tl.constexpr,
    grouped_idx_ptr, expert_idxs_ptr,
    FAN_OUT: tl.constexpr,
    M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
    allow_tf32: tl.constexpr,
    x_grouped: tl.constexpr, y_grouped: tl.constexpr,
    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
    """Compute one BLOCK_M x BLOCK_N tile of Y = X @ W[expert] (+ B[expert]).

    The launch grid is one-dimensional; the program id is split into a row-tile
    index and a column-tile index. Each row position reads its expert id from
    `expert_idxs_ptr` and its scattered index from `grouped_idx_ptr`.

    Input rows are taken at the row position itself when `x_grouped`, otherwise
    at `scattered index // FAN_OUT`. Output rows are written at the row position
    when `y_grouped`, otherwise at the scattered index.
    """
    pid = tl.program_id(axis=0)

    # Split the flat program id into (row tile, column tile).
    n_tiles = tl.cdiv(N, BLOCK_N)
    row_tile = pid // n_tiles
    col_tile = pid % n_tiles

    rows = row_tile * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = col_tile * BLOCK_N + tl.arange(0, BLOCK_N)
    col_ok = cols < N
    row_ok = rows < (FAN_OUT * M)
    tile_ok = row_ok[:, None] & col_ok[None, :]

    # Out-of-range rows are given expert id E; the expert loop below stops at
    # E - 1, so those rows never match any expert.
    row_experts = tl.load(expert_idxs_ptr + rows, mask=row_ok, other=E)
    scattered_rows = tl.load(grouped_idx_ptr + rows, mask=row_ok).to(tl.int32)

    # The input row selection does not depend on the expert being processed.
    if x_grouped:
        in_rows = rows
    else:
        in_rows = scattered_rows // FAN_OUT

    k_divisible = K % BLOCK_K == 0

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    first_expert = tl.min(row_experts)
    last_expert = tl.minimum(tl.max(row_experts), E - 1)
    # Visit every expert id between the smallest and largest one in this tile;
    # each pass only accumulates the rows assigned to that expert.
    for expert in range(first_expert, last_expert + 1):
        acc = _compute_expert_block(
            expert, row_experts == expert,
            in_rows, cols, col_ok,
            X_ptr, stride_xm, stride_xk,
            W_ptr, stride_we, stride_wk, stride_wn,
            K,
            acc,
            k_divisible,
            BLOCK_K,
            allow_tf32=allow_tf32,
        )

    if B_ptr is not None:
        # Per-row bias, selected by each row's expert id.
        bias_ptrs = B_ptr + row_experts[:, None] * stride_be + cols[None, :] * stride_bn
        acc += tl.load(bias_ptrs, mask=tile_ok)

    if y_grouped:
        out_rows = rows
    else:
        out_rows = scattered_rows
    out_ptrs = Y_ptr + (out_rows[:, None] * stride_ym + cols[None, :] * stride_yn)
    tl.store(out_ptrs, acc, mask=tile_ok)
|
|
|
|
|
def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
                    b=None,
                    x_grouped=False, y_grouped=False,
                    out=None):
    """Run the scatter-to-scatter expert matmul and return the result tensor.

    Validates that both index tensors have the same length and that this
    length equals `X.size(0) * k`, picks (or allocates) the output buffer and
    delegates to `scatter2scatter_compileable`.

    Args:
        X: input tensor; its dim 0 times `k` must equal the number of indices.
        W: weights; the last dim is the output width.
        sorted_expert_idxs: expert id per row.
        sorted_scattered_idxs: scattered index per row.
        k: fan-out factor.
        b: optional bias tensor, forwarded unchanged.
        x_grouped, y_grouped: forwarded unchanged to the kernel launch.
        out: optional preallocated `(num_rows, W.size(-1))` buffer to write into.

    Returns:
        `out` when it was supplied, otherwise a newly allocated tensor with
        X's device and dtype.
    """
    num_rows = sorted_expert_idxs.size(0)
    assert sorted_scattered_idxs.size(0) == num_rows
    assert sorted_scattered_idxs.size(0) == X.size(0) * k

    out_width = W.size(-1)
    if out is not None:
        assert out.size(0) == num_rows and out.size(1) == out_width
        result = out
    else:
        result = torch.empty((num_rows, out_width), device=X.device, dtype=X.dtype)

    scatter2scatter_compileable(result, W, X, k, sorted_expert_idxs, sorted_scattered_idxs,
                                b, x_grouped, y_grouped)
    return result
|
|
|
|
|
|
|
|
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
def scatter2scatter_compileable(
        output: torch.Tensor,
        W: torch.Tensor,
        X: torch.Tensor,
        k: int,
        sorted_expert_idxs: torch.Tensor,
        sorted_scattered_idxs: torch.Tensor,
        b: Optional[torch.Tensor],
        x_grouped: bool, y_grouped: bool) -> None:
    """Launch the `_scatter2scatter` kernel, writing the result into `output`.

    Registered as the custom op `scattermoe::scatter2scatter`; `output` is the
    only argument declared as mutated. When `b` is None the bias strides are
    passed as 0 and the kernel receives None as its bias pointer.
    """
    num_rows = sorted_expert_idxs.size(0)

    def grid(META):
        # One program per (row tile, column tile), flattened into one axis.
        row_tiles = triton.cdiv(num_rows, META["BLOCK_M"])
        col_tiles = triton.cdiv(META['N'], META['BLOCK_N'])
        return (row_tiles * col_tiles,)

    if b is None:
        stride_be, stride_bn = 0, 0
    else:
        stride_be, stride_bn = b.stride()

    _scatter2scatter[grid](
        # X_ptr, stride_xm, stride_xk
        X, X.stride(0), X.stride(1),
        # W_ptr, stride_we, stride_wk, stride_wn
        W, W.stride(0), W.stride(1), W.stride(2),
        # Y_ptr, stride_ym, stride_yn
        output, output.stride(0), output.stride(1),
        # B_ptr, stride_be, stride_bn
        b, stride_be, stride_bn,
        grouped_idx_ptr=sorted_scattered_idxs,
        expert_idxs_ptr=sorted_expert_idxs,
        FAN_OUT=k,
        M=X.size(0),
        K=X.size(1),
        N=output.size(1), E=W.size(0),
        BLOCK_M=BLOCK_M,
        ACC_TYPE=tl.float32,
        allow_tf32=ALLOW_TF32,
        x_grouped=x_grouped, y_grouped=y_grouped,
    )
|
|
|
|
|
|
|
|
def _config_XtY():
    """Return the autotune candidate list used by the `_groupXtY` kernel."""
    tile_sizes = {'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}
    only_candidate = triton.Config(tile_sizes, num_stages=4, num_warps=4)
    return [only_candidate]
|
|
|
|
|
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
    """Compute per-expert weight (and optionally bias) gradients.

    Allocates zero-filled gradient buffers and fills them through
    `groupXtY_compileable`.

    Args:
        DY: output gradients; the last dim is the output width.
        X: inputs; the last dim is the input width.
        expert_offsets: per-expert row offsets, forwarded to the kernel launch.
        E: number of experts.
        has_bias: when true, a zero-filled `(E, DY.size(-1))` bias gradient is
            allocated and filled as well.

    Returns:
        `(DW, Db)`, where `DW` is an `(E, X.size(-1), DY.size(-1))` permuted
        view of the underlying `(E, DY.size(-1), X.size(-1))` buffer and `Db`
        is None unless `has_bias`.
    """
    out_width = DY.size(-1)
    in_width = X.size(-1)

    # Buffers start at zero: the kernel writes nothing for an expert whose row
    # range is empty.
    DW = torch.zeros(
        (E, out_width, in_width), device=DY.device, dtype=DY.dtype
    ).permute(0, 2, 1)
    Db = (
        torch.zeros((E, out_width), device=DY.device, dtype=DY.dtype)
        if has_bias
        else None
    )

    groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
    return DW, Db
|
|
|
|
|
|
|
|
# `Db` is listed alongside `DW` because the kernel stores the bias gradient
# into it whenever it is not None; custom_op requires every mutated argument
# to be declared.
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
def groupXtY_compileable(
        E: int,
        DW: torch.Tensor,
        Db: Optional[torch.Tensor],
        DY: torch.Tensor,
        X: torch.Tensor,
        expert_offsets: torch.Tensor) -> None:
    """Launch the `_groupXtY` kernel, filling `DW` (and `Db` when given) in place.

    Args:
        E: number of experts; scales the first grid axis.
        DW: weight-gradient buffer, written in place.
        Db: optional bias-gradient buffer, written in place when not None.
        DY: output gradients.
        X: inputs.
        expert_offsets: per-expert row offsets read by the kernel.
    """
    def grid(META):
        # Axis 0: one program per (expert, K tile); axis 1: one per N tile.
        return (
            E * triton.cdiv(META['K'], META['BLOCK_K']),
            triton.cdiv(META['N'], META['BLOCK_N']),
        )

    if Db is None:
        # No bias gradient requested: the strides are never used by the kernel.
        stride_dbe = 0
        stride_dbn = 0
    else:
        stride_dbe, stride_dbn = Db.stride()

    _groupXtY[grid](
        # DY_ptr, stride_dym, stride_dyk
        DY, DY.stride(0), DY.stride(1),
        # X_ptr, stride_xm, stride_xn
        X, X.stride(0), X.stride(1),
        # DW_ptr, stride_dwe, stride_dwk, stride_dwn
        DW, DW.stride(0), DW.stride(1), DW.stride(2),
        # Db_ptr, stride_dbe, stride_dbn
        Db, stride_dbe, stride_dbn,
        # expert_offsets_ptr
        expert_offsets,
        # M, N, K
        M=DY.size(0), N=DY.size(-1), K=X.size(-1),
        ACC_TYPE=tl.float32,
        allow_tf32=ALLOW_TF32
    )
|
|
|
|
|
|
|
|
@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
})
@triton.jit
def _groupXtY(
    DY_ptr, stride_dym, stride_dyk,
    X_ptr, stride_xm, stride_xn,
    DW_ptr, stride_dwe, stride_dwk, stride_dwn,
    Db_ptr, stride_dbe, stride_dbn,
    expert_offsets_ptr,
    M, K: tl.constexpr, N: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
    allow_tf32: tl.constexpr,
    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
):
    """Compute one BLOCK_K x BLOCK_N tile of DW[expert] = X^T @ DY.

    Grid axis 0 encodes (expert, K tile) and axis 1 the N tile. The expert's
    row range is `[expert_offsets[e - 1], expert_offsets[e])`, with 0 as the
    start for expert 0. Programs whose expert has an empty range do nothing.
    When `Db_ptr` is not None, the program handling K tile 0 also produces the
    bias gradient for its N tile.
    """
    pid0, pid1 = tl.swizzle2d(
        tl.program_id(axis=0), tl.program_id(axis=1),
        tl.num_programs(0), tl.num_programs(1),
        4,
    )

    k_tiles = tl.cdiv(K, BLOCK_K)
    E_idx = pid0 // k_tiles
    k_tile = pid0 % k_tiles
    n_tile = pid1

    # Row range owned by this expert.
    if E_idx == 0:
        row_begin = 0
    else:
        row_begin = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
    row_end = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)

    if row_begin < row_end:
        rows = tl.max_contiguous(row_begin + tl.arange(0, BLOCK_M), BLOCK_M)

        # Offsets are wrapped with a modulo after the validity mask is taken,
        # so the mask still reflects the unwrapped positions.
        k_offs = k_tile * BLOCK_K + tl.arange(0, BLOCK_K)
        k_ok = k_offs < K
        k_offs = tl.max_contiguous(tl.multiple_of(k_offs % K, BLOCK_K), BLOCK_K)

        n_offs = n_tile * BLOCK_N + tl.arange(0, BLOCK_N)
        n_ok = n_offs < N
        n_offs = tl.max_contiguous(tl.multiple_of(n_offs % N, BLOCK_N), BLOCK_N)

        xt_ptrs = X_ptr + k_offs[:, None] * stride_xn + rows[None, :] * stride_xm
        dy_ptrs = DY_ptr + rows[:, None] * stride_dym + n_offs[None, :] * stride_dyk

        # `compute_bias` is a constexpr of the helper while `k_tile` is only
        # known at run time, hence two separate call sites.
        if (Db_ptr is not None) and (k_tile == 0):
            _xty_and_bias(
                E_idx, row_begin, row_end,
                rows,
                k_offs, k_ok, n_offs, n_ok,
                dy_ptrs, stride_dym,
                xt_ptrs, stride_xm,
                DW_ptr, stride_dwe, stride_dwk, stride_dwn,
                Db_ptr, stride_dbe, stride_dbn,
                BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
                allow_tf32, NO_K_MASK, NO_N_MASK,
                compute_bias=True
            )
        else:
            _xty_and_bias(
                E_idx, row_begin, row_end,
                rows,
                k_offs, k_ok, n_offs, n_ok,
                dy_ptrs, stride_dym,
                xt_ptrs, stride_xm,
                DW_ptr, stride_dwe, stride_dwk, stride_dwn,
                Db_ptr, stride_dbe, stride_dbn,
                BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
                allow_tf32, NO_K_MASK, NO_N_MASK,
                compute_bias=False
            )
|
|
|
|
|
|
|
|
@triton.jit
def _xty_and_bias(
    E_idx, start_idx, end_idx,
    M_block,
    K_block, K_mask, N_block, N_mask,
    dy_blk_ptrs, stride_dym,
    xt_blk_ptrs, stride_xm,
    DW_ptr, stride_dwe, stride_dwk, stride_dwn,
    Db_ptr, stride_dbe, stride_dbn,
    BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
    allow_tf32, NO_K_MASK, NO_N_MASK,
    compute_bias: tl.constexpr
):
    """Reduce one expert's rows into a DW tile and, optionally, a Db slice.

    Iterates over the rows `[start_idx, end_idx)` in steps of BLOCK_M,
    accumulating `xt @ dy` into a (BLOCK_K, BLOCK_N) accumulator which is then
    stored to `DW_ptr` for expert `E_idx`. When `compute_bias` is true the
    column sums of `dy` are accumulated as well and stored to `Db_ptr`.

    `M_block` holds the row offsets of the first step (starting at
    `start_idx`); `xt_blk_ptrs` / `dy_blk_ptrs` point at the matching tiles
    and are advanced by BLOCK_M rows per step.
    """
    if compute_bias:
        db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
    else:
        db_acc = None

    acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
    iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
    for i in range(0, iters):
        # Rows at or beyond end_idx belong to the next expert and are masked.
        M_mask = (i * BLOCK_M + M_block) < end_idx

        # `other=0.0` makes masked-out lanes explicit zeros: they feed both the
        # dot product and the bias column sum, so they must add exactly nothing.
        if NO_K_MASK:
            xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :], other=0.0)
        else:
            xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :], other=0.0)
        if NO_N_MASK:
            dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None], other=0.0)
        else:
            dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0)

        acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)

        # Advance both tiles to the next BLOCK_M rows.
        xt_blk_ptrs += BLOCK_M * stride_xm
        dy_blk_ptrs += BLOCK_M * stride_dym

        if compute_bias:
            db_acc += tl.sum(dy, axis=0)

    DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
    # Cast the accumulator to the element type of the destination buffer.
    acc = acc.to(DW_blk_ptrs.dtype.element_ty)
    tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
    if compute_bias:
        Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
        tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
|
|
|
|
|
|
|
|
|
|
|
def _config_grouping():
    """Return the autotune candidate list used by the `_group` kernel."""
    tile_sizes = {'BLOCK_N': 256, 'BLOCK_K': 128}
    only_candidate = triton.Config(tile_sizes, num_stages=4, num_warps=4)
    return [only_candidate]
|
|
|
|
|
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
    """Gather rows of `A` into a new row order and return the result.

    Args:
        A: source tensor; `A.size(0) * fan_out` must equal the index count.
        sorted_expert_idxs: index tensor handed to the kernel as its per-row
            source index.
        coeff: optional per-row scaling coefficients.
        fan_out: the kernel integer-divides each index by this value to
            obtain the source row.
        out: optional preallocated output buffer.

    Returns:
        `out` when supplied, otherwise a new `(N, A.size(1))` tensor with A's
        dtype and device, where N is the number of indices.
    """
    N = sorted_expert_idxs.size(0)
    K = A.size(1)
    assert A.size(0) * fan_out == N

    Y = out if out is not None else torch.empty((N, K), dtype=A.dtype, device=A.device)

    has_coeff = coeff is not None
    group_compileable(A, K, N, Y, coeff, has_coeff, fan_out, sorted_expert_idxs)
    return Y
|
|
|
|
|
|
|
|
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
def group_compileable(
        A: torch.Tensor,
        K: int,
        N: int,
        Y: torch.Tensor,
        # Optional: `group` passes None here whenever no coefficients are given,
        # and the op schema is inferred from these annotations.
        coeff: Optional[torch.Tensor], has_coeff: bool,
        fan_out: int,
        sorted_expert_idxs: torch.Tensor) -> None:
    """Launch the `_group` kernel, writing the gathered rows into `Y`.

    Args:
        A: source tensor.
        K: number of columns to copy per row.
        N: number of output rows.
        Y: destination tensor, written in place.
        coeff: optional per-row coefficients; only read when `has_coeff`.
        has_coeff: whether the kernel should scale rows by `coeff`.
        fan_out: divisor applied to each index to obtain the source row.
        sorted_expert_idxs: per-row indices passed as the kernel's
            `grouped_idx_ptr`.
    """
    def grid(META):
        # One program per BLOCK_N output rows.
        return (triton.cdiv(META['N'], META['BLOCK_N']),)

    _group[grid](
        # src_ptr, stride_sn, stride_sk, has_coeff, coeff_ptr, FAN_OUT
        A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out,
        # tgt_ptr, stride_tn, stride_ti
        Y, Y.stride(0), Y.stride(1),
        # grouped_idx_ptr
        sorted_expert_idxs,
        # N, K
        N, K
    )
|
|
|
|
|
|
|
|
@triton.autotune(configs=_config_grouping(), key=['K'])
@triton.heuristics({
    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
})
@triton.jit
def _group(
    src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
    tgt_ptr, stride_tn, stride_ti,
    grouped_idx_ptr,
    N, K: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NO_K_MASK: tl.constexpr
):
    """Copy BLOCK_N gathered rows from `src_ptr` to `tgt_ptr`.

    Target row `r` receives source row `grouped_idx[r] // FAN_OUT`; when
    `has_coeff` the row is multiplied by `coeff[grouped_idx[r]]` first. The K
    columns are copied in chunks of BLOCK_K.
    """
    rows = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N)
    row_ok = rows < N
    # Wrap the offsets after taking the mask so the mask reflects the
    # unwrapped positions.
    rows = tl.max_contiguous(tl.multiple_of(rows % N, BLOCK_N), BLOCK_N)
    gathered = tl.load(grouped_idx_ptr + rows, mask=row_ok, other=0)

    cols = tl.arange(0, BLOCK_K)
    src_ptrs = src_ptr + (gathered // FAN_OUT)[:, None] * stride_sn + cols[None, :] * stride_sk
    tgt_ptrs = tgt_ptr + rows[:, None] * stride_tn + cols[None, :] * stride_ti

    if has_coeff:
        c = tl.load(coeff_ptr + gathered, mask=row_ok)[:, None]

    row_ok_col = row_ok[:, None]
    iters = tl.cdiv(K, BLOCK_K)
    for i in range(0, iters):
        if NO_K_MASK or i < iters - 1:
            # Full chunk: only the row mask is needed.
            block = tl.load(src_ptrs, mask=row_ok_col)
            if has_coeff:
                block *= c
            tl.store(tgt_ptrs, block, mask=row_ok_col)
        else:
            # Last, partial chunk: also mask the columns past K.
            col_ok = (i * BLOCK_K + cols) < K
            mask = row_ok_col & col_ok[None, :]
            block = tl.load(src_ptrs, mask=mask)
            if has_coeff:
                block *= c
            tl.store(tgt_ptrs, block, mask=mask)
        src_ptrs += BLOCK_K * stride_sk
        tgt_ptrs += BLOCK_K * stride_ti
|
|
|