# Source: Hanrui / sglang / test / srt / cpu / test_gemm.py
# Uploaded by Lekr0: "Add files using upload-large-folder tool"
# Commit a402b9b (verified)
import itertools
import unittest
# TODO: use interface in cpu.py
import torch
import torch.nn as nn
from utils import (
convert_weight,
native_w8a8_per_token_matmul,
per_token_quant_int8,
precision,
unpack_and_dequant_awq,
)
from sglang.test.test_utils import CustomTestCase
# Seed torch's global RNG at import time so that the random tensors built by
# the tests below (torch.randn / torch.rand / torch.randint) are reproducible.
torch.manual_seed(1234)
class Mod(nn.Module):
    """Minimal module that wraps a single ``torch.nn.Linear`` layer.

    Args:
        input_channel: Number of input features of the linear layer.
        output_channel: Number of output features of the linear layer.
        has_bias: Whether the linear layer is created with a bias term.
    """

    def __init__(self, input_channel, output_channel, has_bias):
        super().__init__()
        self.linear = torch.nn.Linear(
            in_features=input_channel,
            out_features=output_channel,
            bias=has_bias,
        )

    def forward(self, x):
        """Apply the wrapped linear layer to ``x`` and return the result."""
        projected = self.linear(x)
        return projected
class TestGemm(CustomTestCase):
    """Compare the ``torch.ops.sgl_kernel`` CPU GEMM ops with reference results.

    Every ``_<kind>_gemm`` helper builds random inputs for a single problem
    size, computes a reference result, runs the matching ``sgl_kernel`` op and
    compares both with ``torch.testing.assert_close``. The tolerance
    (``atol == rtol``) is looked up in ``precision`` by the reference dtype.

    Every ``test_<kind>_gemm`` method sweeps its helper over the Cartesian
    product of the class-level parameter lists below, one ``subTest`` per
    combination so that a failing shape is reported without hiding the others.
    """

    # Problem sizes for the bf16 tests (``K`` is also reused by the small-OC
    # test); ``has_bias`` is swept by every test in this class.
    M = [1, 101]
    N = [16, 32 * 13]
    K = [32 * 16]
    has_bias = [False, True]

    # Problem sizes for the int8 (w8a8, per-token) test.
    M_int8 = [2, 128]
    N_int8 = [32 * 12]
    K_int8 = [32 * 17]

    # Problem sizes for the block-scaled fp8 test.
    M_fp8 = [1, 11]
    N_fp8 = [128, 224]
    K_fp8 = [512, 576]

    # Problem sizes for the int4 AWQ test.
    M_awq = [1, 32]
    N_awq = [4096]
    K_awq = [4096]

    def _bf16_gemm(self, M, N, K, has_bias):
        """Check ``weight_packed_linear`` for one bf16 problem size.

        The op is run twice: with the raw ``(N, K)`` weight and the trailing
        flag ``False``, and with the output of ``convert_weight_packed`` and
        the trailing flag ``True`` (presumably an "already packed" flag --
        confirm against the kernel signature).
        """
        mat1 = torch.randn(M, K, dtype=torch.bfloat16)
        mat2 = torch.randn(N, K, dtype=torch.bfloat16)
        bias = None

        # The reference is computed in fp32 and rounded to bf16 at the end.
        ref = torch.matmul(mat1.float(), mat2.float().t())
        if has_bias:
            bias = torch.randn(N, dtype=torch.float32)
            # The bias is rounded to bf16 before it is added to the reference,
            # while the kernel receives the original fp32 bias.
            ref.add_(bias.bfloat16())
        ref = ref.bfloat16()

        out = torch.ops.sgl_kernel.weight_packed_linear(mat1, mat2, bias, False)

        packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
        out2 = torch.ops.sgl_kernel.weight_packed_linear(
            mat1, packed_mat2, bias, True
        )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, out, atol=atol, rtol=rtol)
        torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol)

    def test_bf16_gemm(self):
        """Sweep ``_bf16_gemm`` over ``M`` x ``N`` x ``K`` x ``has_bias``."""
        for M, N, K, has_bias in itertools.product(
            self.M,
            self.N,
            self.K,
            self.has_bias,
        ):
            with self.subTest(M=M, N=N, K=K, has_bias=has_bias):
                self._bf16_gemm(M, N, K, has_bias)

    def _bf16_gemm_with_small_oc(self, M, N, K, has_bias, use_post_sigmul):
        """Check the packed bf16 linear ops for a small output-channel count.

        When ``use_post_sigmul`` is requested and ``N == 1``, the fused
        ``fused_linear_sigmoid_mul`` op is compared against
        ``sigmoid(linear(x)) * mat_mul``; otherwise the plain
        ``weight_packed_linear`` op is compared against ``linear(x)``.
        """
        # The fused sigmoid-mul path is only exercised for a single output
        # channel; for any other N the request is ignored.
        use_post_sigmul = use_post_sigmul and N == 1
        # Multiplier for the fused path; the (M, 1) sigmoid output broadcasts
        # against its (M, 2 * K) shape in the reference below.
        mat_mul = (
            torch.randn(M, 2 * K, dtype=torch.bfloat16) if use_post_sigmul else None
        )
        mat1 = torch.randn(M, K, dtype=torch.bfloat16)
        mat2 = torch.randn(N, K, dtype=torch.bfloat16)
        bias = None

        ref = torch.nn.functional.linear(mat1, mat2)
        if has_bias:
            bias = torch.randn(N, dtype=torch.float32)
            ref.add_(bias)

        packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
        if use_post_sigmul:
            ref = torch.nn.functional.sigmoid(ref) * mat_mul
            out = torch.ops.sgl_kernel.fused_linear_sigmoid_mul(
                mat1,
                packed_mat2,
                bias,
                True,
                mat_mul,
            )
        else:
            out = torch.ops.sgl_kernel.weight_packed_linear(
                mat1,
                packed_mat2,
                bias,
                True,
            )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, out, atol=atol, rtol=rtol)

    def test_bf16_gemm_with_small_oc(self):
        """Sweep ``_bf16_gemm_with_small_oc`` over small-N problem sizes."""
        for M, N, K, has_bias, use_post_sigmul in itertools.product(
            [1, 8, 32, 1024], [12, 1], self.K, self.has_bias, [False, True]
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                has_bias=has_bias,
                use_post_sigmul=use_post_sigmul,
            ):
                self._bf16_gemm_with_small_oc(M, N, K, has_bias, use_post_sigmul)

    def _int8_gemm(self, M, N, K, has_bias):
        """Check the int8 scaled-matmul ops for one problem size.

        Both the two-step path (``per_token_quant_int8_cpu`` followed by
        ``int8_scaled_mm_cpu``) and the fused ``int8_scaled_mm_with_quant`` op
        are compared against ``native_w8a8_per_token_matmul``.
        """
        dtype = torch.bfloat16
        A = torch.randn((M, K), dtype=dtype) / 10
        Aq, As = per_token_quant_int8(A)

        factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Weight values are drawn uniformly from [-1, 1), scaled to the int8
        # range and clamped before the cast to int8.
        B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2
        Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
        Bs = torch.rand(N) * factor_for_scale

        bias = torch.randn(N) if has_bias else None

        ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype)
        atol = rtol = precision[ref_out.dtype]

        # Two-step path: quantize the activation with the kernel, then matmul.
        Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
        out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
            Aq2, Bq, As2, Bs, bias, dtype, False
        )
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        # Fused path: the op receives the unquantized activation directly.
        fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
            A, Bq, Bs, bias, dtype, False
        )
        torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol)

    def test_int8_gemm(self):
        """Sweep ``_int8_gemm`` over the ``*_int8`` sizes and ``has_bias``."""
        for M, N, K, has_bias in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.has_bias,
        ):
            with self.subTest(M=M, N=N, K=K, has_bias=has_bias):
                self._int8_gemm(M, N, K, has_bias)

    def _fp8_gemm(self, M, N, K, has_bias):
        """Check ``fp8_scaled_mm_cpu`` for one problem size.

        The weight of a freshly initialised ``Mod`` is converted with
        ``convert_weight`` using a ``[64, 128]`` scale block size; the
        dequantized weight it returns is used for the reference matmul.
        """
        # Manual toggles: ``prepack`` packs the fp8 weight before the call,
        # ``chunk`` feeds the kernel a narrowed (non-contiguous) activation.
        prepack = True
        chunk = False
        scale_block_size_N = 64
        scale_block_size_K = 128
        block_size = [scale_block_size_N, scale_block_size_K]
        # unittest assertions are used instead of bare ``assert`` so that the
        # checks survive ``python -O`` and report both values on failure.
        self.assertLessEqual(scale_block_size_N, N)
        self.assertLessEqual(scale_block_size_K, K)

        A_dtype = torch.bfloat16
        model = Mod(K, N, has_bias).eval()
        if chunk:
            data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K)
        else:
            data = torch.randn(M, K, dtype=A_dtype)

        weight = model.linear.weight  # (N, K)
        bias = model.linear.bias if has_bias else None

        fp8_weight, scales, dq_weight = convert_weight(weight, block_size, A_dtype)

        ref = torch.matmul(data, dq_weight.T)
        if has_bias:
            ref = ref + bias.to(A_dtype)

        if prepack:
            fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)

        opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
            data,
            fp8_weight,
            scales,
            block_size,
            bias,
            data.dtype,
            prepack,
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol)

    def test_fp8_gemm(self):
        """Sweep ``_fp8_gemm`` over the ``*_fp8`` sizes and ``has_bias``."""
        for M, N, K, has_bias in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.has_bias,
        ):
            with self.subTest(M=M, N=N, K=K, has_bias=has_bias):
                self._fp8_gemm(M, N, K, has_bias)

    def _int4_awq_gemm(self, M, N, K, group_size, has_bias):
        """Check ``int4_scaled_mm_cpu`` for one AWQ problem size.

        Random AWQ weight / zero-point / scale tensors are dequantized with
        ``unpack_and_dequant_awq`` to build the reference weight, and packed
        with ``convert_weight_packed_scale_zp`` for the kernel under test.
        """
        # The packed tensors have N // 8 columns (presumably eight 4-bit
        # values per int32 -- confirm against the AWQ layout); zero points and
        # scales have one row per group of ``group_size`` rows of K.
        awq_weight = torch.randint(-128, 128, (K, N // 8)).to(torch.int)
        awq_zero = torch.randint(0, 10, (K // group_size, N // 8)).to(torch.int)
        awq_scales = torch.rand(K // group_size, N).to(torch.bfloat16)

        # Forward ``group_size`` instead of the literal 128 that used to be
        # passed here, so the reference stays consistent with the shapes above.
        # NOTE(review): assumes the last argument of unpack_and_dequant_awq is
        # the group size (128 is the only value swept today) -- confirm.
        bf16_weight, _ = unpack_and_dequant_awq(
            awq_weight, awq_zero, awq_scales, 4, group_size
        )

        if has_bias:
            bias = torch.rand(bf16_weight.shape[0]).to(torch.float)
        else:
            bias = None

        x = torch.rand(M, bf16_weight.size(-1)).to(torch.bfloat16)
        # The reference uses a bf16 copy of the bias; the kernel receives the
        # fp32 bias.
        ref_res = torch.nn.functional.linear(
            x, bf16_weight, bias=bias.to(torch.bfloat16) if has_bias else None
        )

        packed_weight, packed_zero, packed_scales = (
            torch.ops.sgl_kernel.convert_weight_packed_scale_zp(
                awq_weight, awq_zero, awq_scales
            )
        )
        target_res = torch.ops.sgl_kernel.int4_scaled_mm_cpu(
            x,
            packed_weight,
            packed_zero,
            packed_scales,
            bias,
        )

        atol = rtol = precision[ref_res.dtype]
        torch.testing.assert_close(ref_res, target_res, atol=atol, rtol=rtol)

    def test_int4_awq_gemm(self):
        """Sweep ``_int4_awq_gemm`` over the ``*_awq`` sizes with group size 128."""
        for M, N, K, group_size, has_bias in itertools.product(
            self.M_awq, self.N_awq, self.K_awq, [128], self.has_bias
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                group_size=group_size,
                has_bias=has_bias,
            ):
                self._int4_awq_gemm(M, N, K, group_size, has_bias)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()