import itertools
import math
import unittest
# TODO: use interface in cpu.py
import torch
from utils import (
BLOCK_K,
BLOCK_N,
SiluAndMul,
factor_for_scale,
fp8_max,
fp8_min,
per_token_quant_int8,
precision,
scaled_weight,
torch_naive_moe,
torch_w8a8_per_column_moe,
)
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.test.test_utils import CustomTestCase
# Seed the global PyTorch RNG so the torch.randn inputs built by the tests
# below are reproducible from run to run.
torch.manual_seed(1234)
class TestSharedExpert(CustomTestCase):
    """Check ``torch.ops.sgl_kernel.shared_expert_cpu`` against PyTorch references.

    Three weight formats are exercised: plain bf16, int8 weights with scales
    from ``per_token_quant_int8``, and block-scaled fp8 weights. Each test
    sweeps the cartesian product of the size lists below and compares the
    kernel result with a reference using the tolerance from ``precision``.

    NOTE(review): the op is called positionally. The argument meanings noted
    in the comments at each call site are inferred from how the three tests
    differ from one another; confirm them against the op schema.
    """

    # Problem sizes swept by the bf16 and int8 tests.
    M = [2, 121]
    N = [32, 32 * 4]
    K = [32, 32 * 2]
    routed_scaling_factor = [16]

    # Problem sizes swept by the fp8 test.
    M_fp8 = [2, 12]
    N_fp8 = [512]
    K_fp8 = [256]

    def _run_cases(self, check, dims, names):
        """Run ``check`` once per combination of ``dims`` and scaling factor.

        Args:
            check: Bound helper taking ``(m, n, k, routed_scaling_factor)``.
            dims: Three lists of sizes, in the order ``check`` expects them.
            names: Keyword names used to label the first three values in
                ``subTest``; the fourth is always ``routed_scaling_factor``.
        """
        labels = (*names, "routed_scaling_factor")
        for params in itertools.product(*dims, self.routed_scaling_factor):
            with self.subTest(**dict(zip(labels, params))):
                check(*params)

    def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
        """Compare the bf16 kernel path with ``torch_naive_moe`` for one shape."""
        dtype = torch.bfloat16

        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k

        # The fused kernel mutates the content of its hidden-states input, so
        # it gets its own copy and the reference reads the untouched tensor.
        kernel_input = hidden_states.clone()

        ref = torch_naive_moe(
            hidden_states.float(),
            w1.float(),
            w2.float(),
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)

        res = torch.ops.sgl_kernel.shared_expert_cpu(
            kernel_input,
            w1,
            w2,
            fused_output,
            routed_scaling_factor,
            True,  # presumably: inplace
            False,  # presumably: use int8 weights
            False,  # presumably: use fp8 weights
            None,  # presumably: w1 scale
            None,  # presumably: w2 scale
            None,  # presumably: block size
            False,  # presumably: weights already packed
        )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_bf16_shared_expert(self):
        """Sweep the bf16 check over ``M`` x ``N`` x ``K``."""
        self._run_cases(
            self._bf16_shared_expert,
            (self.M, self.N, self.K),
            ("m", "n", "k"),
        )

    def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
        """Compare the int8 kernel path with ``torch_w8a8_per_column_moe``."""
        dtype = torch.bfloat16

        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k

        # The fused kernel mutates the content of its hidden-states input, so
        # it gets its own copy and the reference reads the untouched tensor.
        kernel_input = hidden_states.clone()

        w1_q, w1_s = per_token_quant_int8(w1)
        w2_q, w2_s = per_token_quant_int8(w2)

        ref = torch_w8a8_per_column_moe(
            hidden_states.float(),
            w1_q,
            w2_q,
            w1_s,
            w2_s,
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)

        res = torch.ops.sgl_kernel.shared_expert_cpu(
            kernel_input,
            w1_q,
            w2_q,
            fused_output,
            routed_scaling_factor,
            True,  # presumably: inplace
            True,  # presumably: use int8 weights
            False,  # presumably: use fp8 weights
            w1_s,
            w2_s,
            None,  # presumably: block size
            False,  # presumably: weights already packed
        )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_int8_shared_expert(self):
        """Sweep the int8 check over ``M`` x ``N`` x ``K``."""
        self._run_cases(
            self._int8_shared_expert,
            (self.M, self.N, self.K),
            ("m", "n", "k"),
        )

    def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
        """Compare the block-scaled fp8 kernel path with a matmul reference."""
        # NOTE(review): presumably the fp8 path reads the global server args;
        # confirm why this is needed only for this test.
        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))

        dtype = torch.bfloat16

        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        # Weights are created with a leading dimension of 1 and squeezed back
        # to 2D once the dequantized reference copies have been built.
        w1_fp32 = torch.randn(1, 2 * N, K)
        w1 = (
            (w1_fp32 * fp8_max)
            .clamp(min=fp8_min, max=fp8_max)
            .to(torch.float8_e4m3fn)
        )
        w2_fp32 = torch.randn(1, K, N)
        w2 = (
            (w2_fp32 * fp8_max)
            .clamp(min=fp8_min, max=fp8_max)
            .to(torch.float8_e4m3fn)
        )

        # One scale per (BLOCK_N, BLOCK_K) tile of each weight matrix.
        w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
        w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale

        # The views already yield 2D tensors, so no further squeeze is needed.
        w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
        w2_scaled = scaled_weight(w2, w2s).view(K, N)

        # change back to 2D
        w1, w2 = w1.squeeze(0), w2.squeeze(0)
        w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)

        fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        # The kernel gets its own copy; the reference reads the untouched
        # tensor.
        kernel_input = a.clone()

        # Reference: gate/up projection, SiluAndMul, down projection, then
        # add the scaled routed-expert output.
        ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
        ic1 = SiluAndMul(ic0)
        shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1))
        ref_out = shared_out + fused_out.float() * routed_scaling_factor
        ref_out = ref_out.to(dtype=dtype)

        w1 = torch.ops.sgl_kernel.convert_weight_packed(w1)  # [2N, K]
        w2 = torch.ops.sgl_kernel.convert_weight_packed(w2)  # [K, N]

        out = torch.ops.sgl_kernel.shared_expert_cpu(
            kernel_input,
            w1,
            w2,
            fused_out,
            routed_scaling_factor,
            True,  # presumably: inplace
            False,  # presumably: use int8 weights
            True,  # presumably: use fp8 weights
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            True,  # presumably: weights already packed (converted above)
        )

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_fp8_shared_expert(self):
        """Sweep the fp8 check over ``M_fp8`` x ``N_fp8`` x ``K_fp8``."""
        self._run_cases(
            self._fp8_shared_expert,
            (self.M_fp8, self.N_fp8, self.K_fp8),
            ("M", "N", "K"),
        )
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|