import itertools
import unittest

import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.test_utils import CustomTestCase

register_cuda_ci(est_time=8, suite="stage-b-test-small-1-gpu")


def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
| """Matrix multiplication function that supports per-token input quantization and per-column weight quantization""" |
| A = A.to(torch.float32) |
| B = B.to(torch.float32) |
|
|
| assert A.shape[-1] == B.shape[-1], "Dimension mismatch" |
| assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" |
|
|
| |
| M = A.numel() // A.shape[-1] |
| B = B.t() |
| N, K = B.shape |
| origin_C_shape = A.shape[:-1] + (K,) |
| A = A.reshape(M, N) |
|
|
| |
| C = torch.matmul(A, B) |
| C = As * C * Bs.view(1, -1) |
|
|
| return C.reshape(origin_C_shape).to(output_dtype) |
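
# Sanity sketch (illustrative, not part of the test): for int8 A of shape (M, K)
# with per-token scales As of shape (M, 1), and int8 B of shape (N, K) with
# per-column scales Bs of shape (N,), the function above approximates
#     (A.float() * As) @ (B.float() * Bs[:, None]).t()
# because As * (A @ B.t()) * Bs factors both scales out of the matmul.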


def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
| """This function performs fused moe with per-column int8 quantization using native torch.""" |
|
|
| set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) |
|
|
| B, D = a.shape |
| |
| a_q, a_s = per_token_quant_int8(a) |
| |
| a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) |
| |
| a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) |
|
|
| out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) |
|
|
| |
| score = torch.softmax(score, dim=-1, dtype=torch.float32) |
| topk_weight, topk_ids = torch.topk(score, topk) |
| topk_weight = topk_weight.view(-1) |
| topk_ids = topk_ids.view(-1) |
| |
| for i in range(w1.shape[0]): |
| mask = topk_ids == i |
| if mask.sum(): |
| |
| inter_out = native_w8a8_per_token_matmul( |
| a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype |
| ) |
| |
| act_out = SiluAndMul().forward_native(inter_out) |
| |
| act_out_q, act_out_s = per_token_quant_int8(act_out) |
|
|
| |
| out[mask] = native_w8a8_per_token_matmul( |
| act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype |
| ) |
| |
| return ( |
| out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) |
| ).sum(dim=1) |
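
# The eager loop above is the ground truth for this test: fused_moe (called in
# the class below) is expected to match the same quantize -> GEMM -> SiLU-and-mul
# -> requantize -> GEMM -> weighted-sum pipeline within tolerance. (This states
# the test's intent; the kernel's internal scheduling is not specified here.)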


class TestW8A8Int8FusedMoE(CustomTestCase):
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 33]
    N = [128, 1024]
    K = [256, 4096]
    E = [8]
    TOP_KS = [2, 6]
    # block_size is threaded through the sweep but unused here (block_shape=None).
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        # Subsequent tensor constructors allocate on the GPU by default.
        torch.set_default_device("cuda")

    def _w8a8_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
        torch.manual_seed(seed)

        factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Keep activations small so the quantized reference stays well conditioned.
        a = torch.randn((M, K), dtype=dtype) / 10

        # w1 fuses the gate and up projections (2 * N rows); w2 is the down projection.
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # Per-column (output-channel) weight scales for each expert.
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
        score = torch.randn((M, E), dtype=dtype)

        with torch.inference_mode():
            ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
            topk_output = select_experts(
                hidden_states=a,
                router_logits=score,
                topk_config=TopKConfig(top_k=topk, renormalize=False),
            )
            out = fused_moe(
                a,
                w1,
                w2,
                topk_output,
                use_fp8_w8a8=False,
                use_int8_w8a16=False,
                use_int8_w8a8=True,
                per_channel_quant=True,
                w1_scale=w1_s,
                w2_scale=w2_s,
                block_shape=None,
            )

        # Require the mean relative error against the torch reference to stay below 5%.
        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / torch.mean(torch.abs(ref_out.to(torch.float32)))
            < 0.05
        )
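
        # A 5% mean relative error is a loose bound; it presumably absorbs int8
        # rounding plus accumulation-order differences between the Triton kernel
        # and the eager reference (a reading of the test, not a documented spec).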
|
|
    def test_w8a8_int8_fused_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
                block_size=params[5],
                dtype=params[6],
                seed=params[7],
            ):
                self._w8a8_int8_fused_moe(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)