| import itertools |
| import math |
| import unittest |
|
|
| |
| import torch |
| from utils import ( |
| BLOCK_K, |
| BLOCK_N, |
| SiluAndMul, |
| factor_for_scale, |
| fp8_max, |
| fp8_min, |
| per_token_quant_int8, |
| precision, |
| scaled_weight, |
| torch_naive_moe, |
| torch_w8a8_per_column_moe, |
| ) |
|
|
| from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler |
| from sglang.test.test_utils import CustomTestCase |
|
|
# Fix the global RNG seed so the tensors drawn with torch.randn in the tests
# below are the same on every run.
_SEED = 1234
torch.manual_seed(_SEED)
|
|
|
|
class TestSharedExpert(CustomTestCase):
    """Check ``torch.ops.sgl_kernel.shared_expert_cpu`` against reference results.

    Three weight formats are exercised: plain bfloat16 weights, int8 weights
    produced by ``per_token_quant_int8`` and block-scaled ``float8_e4m3fn``
    weights. Each public ``test_*`` method sweeps the Cartesian product of the
    shape lists below and reports every combination as its own sub-test.

    NOTE(review): the kernel is always called positionally and its schema is
    not visible in this file. The trailing arguments look like
    ``(inplace, use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size,
    is_vnni)`` -- confirm against the op definition before relying on this.
    """

    # Shapes swept by the bf16 and int8 tests.
    M = [2, 121]
    N = [32, 32 * 4]
    K = [32, 32 * 2]
    routed_scaling_factor = [16]

    # Shapes swept by the fp8 test.
    M_fp8 = [2, 12]
    N_fp8 = [512]
    K_fp8 = [256]

    def _for_each_combination(self, check, labels, *axes):
        """Run ``check`` in a sub-test for every combination of ``axes``.

        Args:
            check: Callable invoked with one value from each axis, in order.
            labels: Keyword names used to label the sub-test, one per axis.
            *axes: Iterables whose Cartesian product is swept.
        """
        for combo in itertools.product(*axes):
            with self.subTest(**dict(zip(labels, combo))):
                check(*combo)

    def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
        """Compare the kernel with ``torch_naive_moe`` for bfloat16 weights.

        Args:
            m: Row count of the ``(m, k)`` activation and fused-output tensors.
            n: Half the row count of ``w1`` (``(2 * n, k)``) and the column
                count of ``w2`` (``(k, n)``).
            k: Column count of the activations and of ``w1``.
            routed_scaling_factor: Forwarded unchanged to both the reference
                helper and the kernel.

        Raises:
            AssertionError: If kernel and reference differ by more than the
                tolerance looked up in ``precision`` for the output dtype.
        """
        dtype = torch.bfloat16

        # The order of the randn calls is kept fixed so the drawn values stay
        # reproducible under the module-level seed.
        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k

        # The reference is computed in float32 and cast back to bf16. It is
        # evaluated before the kernel call, so it never observes anything the
        # kernel might write into its inputs.
        ref = torch_naive_moe(
            hidden_states.float(),
            w1.float(),
            w2.float(),
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
        res = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            w1,
            w2,
            fused_output,
            routed_scaling_factor,
            True,
            False,
            False,
            None,
            None,
            None,
            False,
        )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_bf16_shared_expert(self):
        """Sweep ``M x N x K x routed_scaling_factor`` for the bf16 path."""
        self._for_each_combination(
            self._bf16_shared_expert,
            ("m", "n", "k", "routed_scaling_factor"),
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        )

    def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
        """Compare the kernel with ``torch_w8a8_per_column_moe`` for int8 weights.

        The bf16 weights are quantized with ``per_token_quant_int8``; the
        resulting quantized tensors and scales are given to both the reference
        helper and the kernel.

        Args:
            m: Row count of the ``(m, k)`` activation and fused-output tensors.
            n: Half the row count of ``w1`` (``(2 * n, k)``) and the column
                count of ``w2`` (``(k, n)``) before quantization.
            k: Column count of the activations and of ``w1``.
            routed_scaling_factor: Forwarded unchanged to both the reference
                helper and the kernel.

        Raises:
            AssertionError: If kernel and reference differ by more than the
                tolerance looked up in ``precision`` for the output dtype.
        """
        dtype = torch.bfloat16

        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k

        # The kernel gets its own copy of the activations.
        # NOTE(review): presumably because the kernel may overwrite its first
        # argument -- confirm against the op implementation.
        kernel_input = hidden_states.clone()

        w1_q, w1_s = per_token_quant_int8(w1)
        w2_q, w2_s = per_token_quant_int8(w2)
        ref = torch_w8a8_per_column_moe(
            hidden_states.float(),
            w1_q,
            w2_q,
            w1_s,
            w2_s,
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
        res = torch.ops.sgl_kernel.shared_expert_cpu(
            kernel_input,
            w1_q,
            w2_q,
            fused_output,
            routed_scaling_factor,
            True,
            True,
            False,
            w1_s,
            w2_s,
            None,
            False,
        )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_int8_shared_expert(self):
        """Sweep ``M x N x K x routed_scaling_factor`` for the int8 path."""
        self._for_each_combination(
            self._int8_shared_expert,
            ("m", "n", "k", "routed_scaling_factor"),
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        )

    def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
        """Compare the kernel with an inline reference for block-scaled fp8 weights.

        The reference multiplies the activations by the de-quantized ``w1``,
        applies ``SiluAndMul``, multiplies by the de-quantized ``w2`` and adds
        ``fused_out * routed_scaling_factor``, all in float32, before casting
        to bf16.

        Args:
            M: Row count of the ``(M, K)`` activation and fused-output tensors.
            N: Half the row count of ``w1`` (``(2 * N, K)``) and the column
                count of ``w2`` (``(K, N)``).
            K: Column count of the activations and of ``w1``.
            routed_scaling_factor: Multiplier applied to ``fused_out`` in the
                reference; also forwarded to the kernel.

        Raises:
            AssertionError: If kernel and reference differ by more than the
                tolerance looked up in ``precision`` for the output dtype.
        """
        # Install default server args before anything else runs.
        # NOTE(review): presumably read by the helpers or the kernel on this
        # path -- the consumer is not visible in this file.
        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))

        dtype = torch.bfloat16
        fp8_dtype = torch.float8_e4m3fn

        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        # Weights are drawn in float32 with a leading singleton dimension,
        # scaled, clamped to [fp8_min, fp8_max] and then cast to fp8.
        w1_fp32 = torch.randn(1, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(fp8_dtype)

        w2_fp32 = torch.randn(1, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(fp8_dtype)

        # One scale per (BLOCK_N x BLOCK_K) tile of each weight matrix.
        w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
        w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale

        # De-quantized weights, used only by the reference computation.
        w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
        w2_scaled = scaled_weight(w2, w2s).view(K, N)

        # Drop the leading singleton dimension everywhere.
        w1, w2 = w1.squeeze(0), w2.squeeze(0)
        w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)
        w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0)

        fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
        # The kernel gets its own copy of the activations.
        # NOTE(review): presumably because the kernel may overwrite its first
        # argument -- confirm against the op implementation.
        kernel_input = a.clone()

        # Reference result.
        gate_up = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
        activated = SiluAndMul(gate_up)
        shared_out = torch.matmul(activated, w2_scaled.transpose(0, 1))
        ref_out = shared_out + fused_out.float() * routed_scaling_factor
        ref_out = ref_out.to(dtype=dtype)

        # Unlike the bf16/int8 paths, the weights are converted with
        # convert_weight_packed first and the final flag is True.
        w1 = torch.ops.sgl_kernel.convert_weight_packed(w1)
        w2 = torch.ops.sgl_kernel.convert_weight_packed(w2)
        out = torch.ops.sgl_kernel.shared_expert_cpu(
            kernel_input,
            w1,
            w2,
            fused_out,
            routed_scaling_factor,
            True,
            False,
            True,
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            True,
        )

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_fp8_shared_expert(self):
        """Sweep ``M_fp8 x N_fp8 x K_fp8 x routed_scaling_factor`` for fp8."""
        self._for_each_combination(
            self._fp8_shared_expert,
            ("M", "N", "K", "routed_scaling_factor"),
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.routed_scaling_factor,
        )
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
|
|