leideng's picture
download
raw
3.2 kB
import logging
from contextlib import contextmanager
from typing import Tuple
import torch
from sglang.srt.layers.deep_gemm_wrapper import compile_utils
from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401
DEEPGEMM_BLACKWELL,
DEEPGEMM_SCALE_UE8M0,
ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# deep_gemm is imported only when ENABLE_JIT_DEEPGEMM is truthy. When the flag
# is falsy the name `deep_gemm` is never bound in this module, so the wrappers
# below that reference it will raise NameError if called.
if ENABLE_JIT_DEEPGEMM:
    import deep_gemm
    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor  # noqa: F401

# Opt-in switch for the extra scale validation done by _sanity_check_input;
# read once, at import time, from the SGLANG_DEEPGEMM_SANITY_CHECK env var.
_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")


# TODO: consider renaming the wrapper functions below.
def grouped_gemm_nt_f8f8bf16_masked(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
):
    """Call DeepGEMM's masked grouped GEMM inside the compile/execution hook.

    The problem size handed to ``compile_utils.deep_gemm_execution_hook`` is
    read from the operands: ``lhs[0].shape`` is unpacked as
    ``(num_groups, _, k)`` and ``rhs[0].shape`` as ``(_, n, _)``, so both
    must be 3-D. ``expected_m`` is passed to the hook as the M dimension.

    Args:
        lhs: Pair of tensors. ``_sanity_check_input`` treats the second
            element as the scale tensor.
        rhs: Pair of tensors, validated the same way as ``lhs``.
        out: Forwarded unchanged to DeepGEMM (presumably the output buffer
            that is written in place -- confirm against DeepGEMM docs).
        masked_m: Forwarded unchanged to DeepGEMM.
        expected_m: Forwarded to DeepGEMM and used as M for the hook.

    Returns:
        None. The value returned by the DeepGEMM call is not propagated.
    """
    group_count, _, depth = lhs[0].shape
    _, width, _ = rhs[0].shape

    # Optional validation (a no-op unless the sanity-check flag is set),
    # applied to lhs first and then rhs.
    for operand in (lhs, rhs):
        _sanity_check_input(operand)

    hook = compile_utils.deep_gemm_execution_hook(
        expected_m,
        width,
        depth,
        group_count,
        compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED,
    )
    with hook:
        deep_gemm.fp8_m_grouped_gemm_nt_masked(lhs, rhs, out, masked_m, expected_m)
def grouped_gemm_nt_f8f8bf16_contig(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    m_indices: torch.Tensor,
):
    """Call DeepGEMM's contiguous grouped GEMM inside the compile/execution hook.

    The problem size handed to ``compile_utils.deep_gemm_execution_hook`` is
    read from the operands: ``lhs[0].shape`` is unpacked as ``(m, k)`` (so it
    must be 2-D) and ``rhs[0].shape`` as ``(num_groups, n, _)`` (so it must
    be 3-D).

    Args:
        lhs: Pair of tensors. ``_sanity_check_input`` treats the second
            element as the scale tensor.
        rhs: Pair of tensors, validated the same way as ``lhs``.
        out: Forwarded unchanged to DeepGEMM (presumably the output buffer
            that is written in place -- confirm against DeepGEMM docs).
        m_indices: Forwarded unchanged to DeepGEMM.

    Returns:
        None. The value returned by the DeepGEMM call is not propagated.
    """
    rows, depth = lhs[0].shape
    group_count, width, _ = rhs[0].shape

    # Optional validation (a no-op unless the sanity-check flag is set),
    # applied to lhs first and then rhs.
    for operand in (lhs, rhs):
        _sanity_check_input(operand)

    hook = compile_utils.deep_gemm_execution_hook(
        rows,
        width,
        depth,
        group_count,
        compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG,
    )
    with hook:
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
def gemm_nt_f8f8bf16(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
):
    """Call DeepGEMM's plain (non-grouped) GEMM inside the compile/execution hook.

    The problem size handed to ``compile_utils.deep_gemm_execution_hook`` is
    read from the operands: ``lhs[0].shape`` is unpacked as ``(m, k)`` and
    ``rhs[0].shape`` as ``(n, _)``, so both must be 2-D. The group count
    passed to the hook is always 1.

    Args:
        lhs: Pair of tensors. ``_sanity_check_input`` treats the second
            element as the scale tensor.
        rhs: Pair of tensors, validated the same way as ``lhs``.
        out: Forwarded unchanged to DeepGEMM (presumably the output buffer
            that is written in place -- confirm against DeepGEMM docs).

    Returns:
        None. The value returned by the DeepGEMM call is not propagated.
    """
    rows, depth = lhs[0].shape
    width, _ = rhs[0].shape

    # Optional validation (a no-op unless the sanity-check flag is set),
    # applied to lhs first and then rhs.
    for operand in (lhs, rhs):
        _sanity_check_input(operand)

    hook = compile_utils.deep_gemm_execution_hook(
        rows,
        width,
        depth,
        1,  # a plain GEMM is reported to the hook as a single group
        compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16,
    )
    with hook:
        deep_gemm.fp8_gemm_nt(lhs, rhs, out)
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs) -> None:
    """Forward ``gpu_id`` and ``server_args`` to ``compile_utils.update_deep_gemm_config``.

    This is a thin re-export of the ``compile_utils`` function; the
    delegate's return value is discarded.
    """
    compile_utils.update_deep_gemm_config(gpu_id, server_args)
@contextmanager
def configure_deep_gemm_num_sms(num_sms):
    """Context manager that temporarily overrides DeepGEMM's SM count.

    When ``num_sms`` is ``None`` the body runs with no change and
    ``deep_gemm`` is never touched. Otherwise the current value from
    ``deep_gemm.get_num_sms()`` is saved, ``deep_gemm.set_num_sms(num_sms)``
    is applied for the duration of the ``with`` body, and the saved value is
    restored on exit even if the body raises.

    Args:
        num_sms: Value to pass to ``deep_gemm.set_num_sms``, or ``None`` to
            leave the setting alone.

    Yields:
        None.
    """
    # Guard clause: nothing to override, so just run the body.
    if num_sms is None:
        yield
        return

    previous_num_sms = deep_gemm.get_num_sms()
    deep_gemm.set_num_sms(num_sms)
    try:
        yield
    finally:
        # Always put the previous setting back, including on exceptions.
        deep_gemm.set_num_sms(previous_num_sms)
def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
    """Optionally verify that an operand's scales are unchanged by ``ceil_to_ue8m0``.

    Does nothing unless the module-level ``_SANITY_CHECK`` flag is truthy.
    When enabled, the second element of ``x_fp8`` (the scale tensor) is
    compared element-wise with ``ceil_to_ue8m0`` applied to it, and any
    mismatch is reported as an error.

    Args:
        x_fp8: Pair of tensors; only the second element (the scales) is
            inspected.

    Raises:
        AssertionError: If any scale differs from its ``ceil_to_ue8m0``
            value. Raised explicitly rather than via ``assert`` so that the
            opt-in check still runs under ``python -O``.
    """
    if not _SANITY_CHECK:
        return

    _, x_scale = x_fp8

    # Scales with dtype torch.int are skipped entirely
    # (NOTE(review): presumably these are already in a packed integer
    # format for which the float comparison below is meaningless -- confirm).
    if x_scale.dtype == torch.int:
        return

    # Imported lazily, inside the function, so it only happens when the
    # check actually runs.
    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0

    x_scale_ceil = ceil_to_ue8m0(x_scale)
    if not torch.all(x_scale == x_scale_ceil):
        raise AssertionError(f"{x_scale=} {x_scale_ceil=}")

Xet Storage Details

Size:
3.2 kB
·
Xet hash:
a0d42a108458f79f891a20bbb739a7acc15ef55414c1c407273743bcba2a3a85

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.