leideng's picture
download
raw
8.07 kB
import logging
import os
from contextlib import contextmanager
from enum import IntEnum, auto
from typing import Dict, List, Tuple
import torch
from tqdm import tqdm
from sglang.srt.environ import envs
from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ceil_div, get_bool_env_var
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# deep_gemm is only imported when the JIT DeepGEMM path is enabled; every
# function below that touches `deep_gemm` therefore requires this flag.
if ENABLE_JIT_DEEPGEMM:
    import deep_gemm

# M values to warm up: every integer in [1, 16384] by default.
# Rebound by update_deep_gemm_config().
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
# Pre-compile switch read once at import time from the project's env registry.
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
# Whether this process runs the warm-up for all Ms; rebound by
# update_deep_gemm_config().
_DO_COMPILE_ALL = True
# Presumably parsed as booleans from the named environment variables with the
# given string defaults -- confirm against get_bool_env_var.
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

# Force redirect deep_gemm cache_dir: use SGL_DG_CACHE_DIR when set, otherwise
# ~/.cache/deep_gemm. This overwrites any DG_JIT_CACHE_DIR already present in
# the process environment.
os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
)

# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
# NVRTC may lose performance in some cases, and NVCC JIT compilation is also
# reported as 9x faster in the referenced commit, so NVRTC stays off ("0")
# unless SGL_DG_USE_NVRTC says otherwise.
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
    """Recompute the warm-up M range and decide whether this rank pre-compiles.

    Args:
        gpu_id: GPU index of the calling rank.
        server_args: Server configuration; ``chunked_prefill_size`` and
            ``base_gpu_id`` are read from it.

    Side effects:
        Rebinds the module globals ``_BUILTIN_M_LIST``,
        ``_IS_FIRST_RANK_ON_NODE`` and ``_DO_COMPILE_ALL``.
    """
    global _BUILTIN_M_LIST
    global _DO_COMPILE_ALL
    global _IS_FIRST_RANK_ON_NODE

    # Pick the largest M to warm up from the chunked-prefill setting.
    chunked_prefill_size = server_args.chunked_prefill_size
    if chunked_prefill_size < 1:
        # Presumably "chunked prefill disabled" -- confirm against ServerArgs.
        m_max = 1024 * 64
    elif chunked_prefill_size > 8192:
        m_max = chunked_prefill_size * 2
    else:
        m_max = 1024 * 16
    # Hard cap so the warm-up list never exceeds 131072 entries.
    m_max = min(1024 * 128, m_max)
    _BUILTIN_M_LIST = list(range(1, m_max + 1))

    # Compare against the value configured on the instance that was passed in.
    # Reading the attribute from the ServerArgs class could at best see a
    # class-level default and would ignore a per-instance base GPU id.
    _IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id

    # Only the first rank on the node compiles all Ms, so that symbols are
    # loaded at the launch stage rather than at the serving stage.
    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
class DeepGemmKernelType(IntEnum):
    """Kinds of DeepGEMM kernels that can be warmed up by this module."""

    # Values are spelled out; they match what `auto()` assigns (1, 2, 3).
    GROUPED_GEMM_NT_F8F8BF16_MASKED = 1
    GROUPED_GEMM_NT_F8F8BF16_CONTIG = 2
    GEMM_NT_F8F8BF16 = 3
# Records which (kernel_type, n, k, num_groups) combinations have already had
# their warm-up started in this process. A key is inserted (value True) by
# _maybe_compile_deep_gemm_one_type_all() before compilation begins, so each
# combination is attempted at most once.
_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
# TODO improve code
def _maybe_compile_deep_gemm_one_type_all(
    kernel_type: DeepGemmKernelType,
    n: int,
    k: int,
    num_groups: int,
) -> None:
    """Warm up one (kernel_type, n, k, num_groups) combination at most once.

    Does nothing unless pre-compilation is enabled and this process is the
    one selected to compile. The combination is marked as handled before the
    compilation starts, so it is not retried even if compilation fails.
    """
    if not (_ENABLE_JIT_DEEPGEMM_PRECOMPILE and _DO_COMPILE_ALL):
        return

    shape_key = (kernel_type, n, k, num_groups)
    if _INITIALIZATION_DICT.get(shape_key) is not None:
        return
    _INITIALIZATION_DICT[shape_key] = True

    outside_precompile_stage = not _IN_PRECOMPILE_STAGE

    # TODO maybe improve logs
    if outside_precompile_stage and _IS_FIRST_RANK_ON_NODE:
        logger.warning(
            "Entering DeepGEMM JIT Pre-Compile session. "
            "It may take a long time (typically 10-20 mins) "
            "if you have not run `sglang.compile_deep_gemm`. "
            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
            " for pre-compilation to reduce the overhead if you have not run it before. "
            "For example: "
            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
        )

    # The reassurance about a warm cache is only shown while serving, not
    # during the dedicated pre-compile stage.
    if outside_precompile_stage:
        cache_hint = " It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. "
    else:
        cache_hint = ""

    logger.info(
        "Try DeepGEMM JIT Compiling for "
        f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
        f"{cache_hint}"
    )

    _compile_deep_gemm_one_type_all(
        kernel_type,
        n,
        k,
        num_groups,
        m_list=_BUILTIN_M_LIST,
    )
# NOTE(alcanderian): get_num_sms should be changed when 2-batch-overlap is introduced
def _compile_deep_gemm_one_type_all(
    kernel_type: DeepGemmKernelType,
    n: int,
    k: int,
    num_groups: int,
    m_list: List[int],
) -> None:
    """Run the warm-up executor of ``kernel_type`` once for every M in ``m_list``.

    Args:
        kernel_type: Which DeepGEMM kernel family to warm up.
        n: N dimension passed to the warm-up executor.
        k: K dimension passed to the warm-up executor.
        num_groups: Group count passed to the warm-up executor.
        m_list: Candidate M values. For the contiguous grouped kernel only the
            values that are multiples of DeepGEMM's M alignment are kept.

    Nothing is done when no M value remains to warm up. The DeepGEMM compile
    mode is switched to 1 for the duration of the loop and always restored
    afterwards, including when a warm-up call raises.
    """
    if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
        m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
        m_list = sorted({m for m in m_list if m % m_alignment == 0})

    # `max()` below would raise ValueError on an empty sequence, and there is
    # nothing to compile in that case anyway.
    if not m_list:
        return

    executor = _BaseWarmupExecutor.create(
        kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
    )

    old_compile_mode = deep_gemm.get_compile_mode()
    # Mode 1 is presumably "compile only" -- confirm against the deep_gemm
    # build in use.
    deep_gemm.set_compile_mode(1)
    try:
        # TODO can use multi thread
        for m in tqdm(m_list, desc="DeepGEMM warmup"):
            executor.execute(m=m)
    finally:
        # The compile mode is process-wide state; never leave it switched.
        deep_gemm.set_compile_mode(old_compile_mode)

    # clean up input buffers
    torch.cuda.current_stream().synchronize()
    del executor
    torch.cuda.empty_cache()
class _BaseWarmupExecutor:
    """Common interface of the per-kernel warm-up executors."""

    @staticmethod
    def create(kernel_type: DeepGemmKernelType, **kwargs):
        """Instantiate the executor class registered for ``kernel_type``.

        ``kwargs`` are forwarded unchanged to the executor constructor. An
        unknown ``kernel_type`` raises ``KeyError``.
        """
        executor_classes = {
            DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
        }
        executor_cls = executor_classes[kernel_type]
        return executor_cls(**kwargs)

    def execute(self, m):
        """Run one warm-up call for the given M; subclasses must override."""
        raise NotImplementedError
def _empty_token_fp8(size):
    """Allocate an uninitialized fp8 tensor of ``size`` plus its float32 companion.

    The companion tensor keeps the leading dimensions of ``size`` and has
    ``ceil_div(k, _BLOCK_SIZE)`` entries along the last one, where ``k`` is the
    last dimension of ``size``. Both tensors are created on the CUDA device.
    """
    *leading_dims, k = size
    values = torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn)
    companion_shape = (*leading_dims, ceil_div(k, _BLOCK_SIZE))
    companion = torch.empty(companion_shape, device="cuda", dtype=torch.float32)
    return values, companion
def _empty_block_fp8(size):
    """Allocate an uninitialized fp8 tensor of ``size`` plus its float32 companion.

    The companion tensor keeps the leading dimensions of ``size`` and replaces
    the last two (``n``, ``k``) with ``ceil_div(n, _BLOCK_SIZE)`` and
    ``ceil_div(k, _BLOCK_SIZE)``. Both tensors are created on the CUDA device.
    """
    *leading_dims, n, k = size
    values = torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn)
    companion_shape = (
        *leading_dims,
        ceil_div(n, _BLOCK_SIZE),
        ceil_div(k, _BLOCK_SIZE),
    )
    companion = torch.empty(companion_shape, device="cuda", dtype=torch.float32)
    return values, companion
# Granularity used by _empty_token_fp8 / _empty_block_fp8 when sizing the
# float32 companion tensors (presumably quantization scales -- confirm): one
# entry per 128 elements along each blocked dimension.
_BLOCK_SIZE = 128
class _NormalWarmupExecutor(_BaseWarmupExecutor):
    """Warm-up executor that calls ``deep_gemm.fp8_gemm_nt``."""

    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
        # `num_groups` is accepted for a uniform constructor signature but is
        # not used by this executor.
        lhs_pair = _empty_token_fp8((max_m, k))
        rhs_pair = _empty_block_fp8((n, k))
        self.lhs_q, self.lhs_s = lhs_pair
        self.rhs_q, self.rhs_s = rhs_pair
        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)

    def execute(self, m):
        """Invoke the kernel on the first ``m`` rows of the preallocated buffers."""
        lhs = (self.lhs_q[:m], self.lhs_s[:m])
        rhs = (self.rhs_q, self.rhs_s)
        deep_gemm.fp8_gemm_nt(lhs, rhs, self.out[:m])
class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
    """Warm-up executor that calls ``deep_gemm.m_grouped_fp8_gemm_nt_contiguous``."""

    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
        lhs_pair = _empty_token_fp8((max_m, k))
        rhs_pair = _empty_block_fp8((num_groups, n, k))
        self.lhs_q, self.lhs_s = lhs_pair
        self.rhs_q, self.rhs_s = rhs_pair
        # All-zero int32 indices, one per row.
        self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)

    def execute(self, m):
        """Invoke the kernel on the first ``m`` rows of the preallocated buffers."""
        lhs = (self.lhs_q[:m], self.lhs_s[:m])
        rhs = (self.rhs_q, self.rhs_s)
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            lhs,
            rhs,
            self.out[:m],
            m_indices=self.m_indices[:m],
        )
class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
    """Warm-up executor that calls ``deep_gemm.fp8_m_grouped_gemm_nt_masked``."""

    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
        lhs_pair = _empty_token_fp8((num_groups, max_m, k))
        rhs_pair = _empty_block_fp8((num_groups, n, k))
        self.lhs_q, self.lhs_s = lhs_pair
        self.rhs_q, self.rhs_s = rhs_pair
        # All-zero int32 mask lengths, one per group.
        self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
        out_shape = (num_groups, max_m, n)
        self.out = torch.empty(out_shape, device="cuda", dtype=torch.bfloat16)

    def execute(self, m):
        """Invoke the kernel on the full buffers, passing ``m`` as ``expected_m``."""
        lhs = (self.lhs_q, self.lhs_s)
        rhs = (self.rhs_q, self.rhs_s)
        deep_gemm.fp8_m_grouped_gemm_nt_masked(
            lhs,
            rhs,
            self.out,
            masked_m=self.masked_m,
            # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
            expected_m=m,
        )
@contextmanager
def deep_gemm_execution_hook(
    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
):
    """Context manager that triggers the one-time warm-up before the wrapped code runs.

    ``m`` is accepted but not used here: the warm-up is keyed on
    ``(kernel_type, n, k, num_groups)`` only. Nothing is done on exit.
    """
    _maybe_compile_deep_gemm_one_type_all(
        kernel_type=kernel_type,
        n=n,
        k=k,
        num_groups=num_groups,
    )
    yield

Xet Storage Details

Size:
8.07 kB
·
Xet hash:
28cd72fdb124d122b1232c20cadaba7f8b3f0fe7b7de363d758fb3b670b60a5c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.