| import os |
|
|
| import torch |
| import triton |
| import triton.language as tl |
| from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda |
| from triton.testing import do_bench |
|
|
| |
| IS_CI = ( |
| os.getenv("CI", "false").lower() == "true" |
| or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" |
| ) |
|
|
|
|
| @triton.jit |
| def _moe_sum_reduce_kernel( |
| input_ptr, |
| input_stride_0, |
| input_stride_1, |
| input_stride_2, |
| output_ptr, |
| output_stride_0, |
| output_stride_1, |
| token_num: int, |
| topk_num: int, |
| hidden_dim: int, |
| routed_scaling_factor: tl.constexpr, |
| BLOCK_M: tl.constexpr, |
| BLOCK_DIM: tl.constexpr, |
| NUM_STAGE: tl.constexpr, |
| ): |
| input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) |
| input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) |
| output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64) |
|
|
| token_block_id = tl.program_id(0) |
| dim_block_id = tl.program_id(1) |
|
|
| offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M) |
| offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM) |
|
|
| mask_token = offs_token < token_num |
| mask_dim = offs_dim < hidden_dim |
|
|
| base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :] |
|
|
| accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32) |
| for i in tl.range(0, topk_num, num_stages=NUM_STAGE): |
| tile = tl.load( |
| base_ptrs + i * input_stride_1, |
| mask=mask_token[:, None] & mask_dim[None, :], |
| other=0.0, |
| ) |
| accumulator += tile.to(tl.float32) |
| accumulator *= routed_scaling_factor |
|
|
| |
| store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :] |
| tl.store( |
| store_ptrs, |
| accumulator.to(input_ptr.dtype.element_ty), |
| mask=mask_token[:, None] & mask_dim[None, :], |
| ) |
|
|
|
|
| |
| def moe_sum_reduce_triton( |
| input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float |
| ): |
| assert input.is_contiguous() |
| assert output.is_contiguous() |
|
|
| token_num, topk_num, hidden_dim = input.shape |
| assert output.shape[0] == token_num and output.shape[1] == hidden_dim |
|
|
| BLOCK_M = 1 |
| BLOCK_DIM = 2048 |
| NUM_STAGE = 1 |
| num_warps = 16 |
|
|
| grid = ( |
| triton.cdiv(token_num, BLOCK_M), |
| triton.cdiv(hidden_dim, BLOCK_DIM), |
| ) |
|
|
| _moe_sum_reduce_kernel[grid]( |
| input, |
| *input.stride(), |
| output, |
| *output.stride(), |
| token_num=token_num, |
| topk_num=topk_num, |
| hidden_dim=hidden_dim, |
| routed_scaling_factor=routed_scaling_factor, |
| BLOCK_M=BLOCK_M, |
| BLOCK_DIM=BLOCK_DIM, |
| NUM_STAGE=NUM_STAGE, |
| num_warps=num_warps, |
| ) |
| return |
|
|
|
|
| def compute_sum_scaled_baseline( |
| x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float |
| ) -> torch.Tensor: |
| torch.sum(x, dim=1, out=out) |
| out.mul_(routed_scaling_factor) |
| return out |
|
|
|
|
| @torch.compile |
| def compute_sum_scaled_compiled( |
| x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float |
| ) -> torch.Tensor: |
| torch.sum(x * routed_scaling_factor, dim=1, out=out) |
| return out |
|
|
|
|
| def get_benchmark(dtype=torch.bfloat16): |
| num_tokens_range = [2**i for i in range(0, 13)] |
|
|
| @triton.testing.perf_report( |
| triton.testing.Benchmark( |
| x_names=["num_tokens"], |
| x_vals=num_tokens_range, |
| line_arg="version", |
| line_vals=["baseline", "compiled", "triton", "cuda"], |
| line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"], |
| styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")], |
| ylabel="us", |
| plot_name=f"sum_scaled_performance_{str(dtype).split('.')[-1]}", |
| args={}, |
| ) |
| ) |
| def benchmark(num_tokens, version): |
| topk = 9 |
| hidden_size = 4096 |
| dtype = torch.bfloat16 |
| scaling_factor = 0.3 |
|
|
| x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda") |
| out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") |
|
|
| |
| for _ in range(3): |
| if version == "baseline": |
| compute_sum_scaled_baseline(x, out, scaling_factor) |
| elif version == "compiled": |
| compute_sum_scaled_compiled(x, out, scaling_factor) |
| elif version == "triton": |
| moe_sum_reduce_triton(x, out, scaling_factor) |
| else: |
| moe_sum_reduce_cuda(x, out, scaling_factor) |
|
|
| |
| quantiles = [0.5, 0.2, 0.8] |
| if version == "baseline": |
| ms, min_ms, max_ms = do_bench( |
| lambda: compute_sum_scaled_baseline(x, out, scaling_factor), |
| quantiles=quantiles, |
| ) |
| elif version == "compiled": |
| ms, min_ms, max_ms = do_bench( |
| lambda: compute_sum_scaled_compiled(x, out, scaling_factor), |
| quantiles=quantiles, |
| ) |
| elif version == "triton": |
| ms, min_ms, max_ms = do_bench( |
| lambda: moe_sum_reduce_triton(x, out, scaling_factor), |
| quantiles=quantiles, |
| ) |
| else: |
| ms, min_ms, max_ms = do_bench( |
| lambda: moe_sum_reduce_cuda(x, out, scaling_factor), |
| quantiles=quantiles, |
| ) |
|
|
| return 1000 * ms, 1000 * max_ms, 1000 * min_ms |
|
|
| return benchmark |
|
|
|
|
| def verify_correctness(num_tokens=1024, dtype=torch.bfloat16): |
| x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=dtype) |
| scaling_factor = 0.3 |
|
|
| out_baseline = torch.empty_like(x[:, 0]) |
| compute_sum_scaled_baseline(x, out_baseline, scaling_factor) |
|
|
| out_compiled = torch.empty_like(out_baseline) |
| compute_sum_scaled_compiled(x, out_compiled, scaling_factor) |
|
|
| out_cuda = torch.empty_like(out_baseline) |
| moe_sum_reduce_cuda(x, out_cuda, scaling_factor) |
|
|
| triton_skipped = dtype == torch.float64 |
| if not triton_skipped: |
| out_triton = torch.empty_like(out_baseline) |
| moe_sum_reduce_triton(x, out_triton, scaling_factor) |
|
|
| if dtype == torch.float64: |
| atol, rtol = 1e-12, 1e-12 |
| elif dtype == torch.float32: |
| atol, rtol = 1e-6, 1e-6 |
| else: |
| atol, rtol = 1e-2, 1e-2 |
|
|
| ok_compiled = torch.allclose(out_baseline, out_compiled, atol=atol, rtol=rtol) |
| ok_cuda = torch.allclose(out_baseline, out_cuda, atol=atol, rtol=rtol) |
| ok_triton = ( |
| True |
| if triton_skipped |
| else torch.allclose(out_baseline, out_triton, atol=atol, rtol=rtol) |
| ) |
|
|
| if ok_compiled and ok_triton and ok_cuda: |
| msg = "✅ All implementations match" |
| if triton_skipped: |
| msg += " (Triton skipped for float64)" |
| print(msg) |
| else: |
| print("❌ Implementations differ") |
| print( |
| f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}" |
| ) |
| if not triton_skipped: |
| print( |
| f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}" |
| ) |
| print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}") |
|
|
|
|
| if __name__ == "__main__": |
| print("Running correctness verification for bfloat16...") |
| verify_correctness(dtype=torch.bfloat16) |
|
|
| |
| if not IS_CI: |
| print("Running correctness verification for float64...") |
| verify_correctness(dtype=torch.float64) |
|
|
| print("Running correctness verification for float64...") |
| verify_correctness(dtype=torch.float64) |
|
|
| print("\nRunning performance benchmark for bfloat16...") |
| benchmark = get_benchmark(dtype=torch.bfloat16) |
| benchmark.run( |
| print_data=True, |
| |
| ) |
|
|