| import logging | |
| import os | |
| from contextlib import contextmanager | |
| from enum import IntEnum, auto | |
| from typing import Dict, List, Tuple | |
| import torch | |
| from tqdm import tqdm | |
| from sglang.srt.environ import envs | |
| from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM | |
| from sglang.srt.server_args import ServerArgs | |
| from sglang.srt.utils import ceil_div, get_bool_env_var | |
logger = logging.getLogger(__name__)

# deep_gemm is only imported when JIT DeepGEMM is enabled; the helpers in this
# module that call deep_gemm.* assume that configuration.
if ENABLE_JIT_DEEPGEMM:
    import deep_gemm

# Candidate M values to pre-compile for each kernel shape.
# update_deep_gemm_config() replaces this with a range sized from the server's
# chunked prefill size.
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
# Master switch for pre-compiling every M of a kernel shape.
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
# Whether this process compiles all Ms; recomputed by update_deep_gemm_config().
_DO_COMPILE_ALL = True
# Whether this process is the first rank on its node, initially read from
# SGL_IS_FIRST_RANK_ON_NODE; recomputed by update_deep_gemm_config().
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
# Whether we are inside the dedicated pre-compile stage; within this section it
# only changes which log messages are emitted.
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

# Force redirect deep_gemm cache_dir: SGL_DG_CACHE_DIR if set, otherwise
# ~/.cache/deep_gemm. Any DG_JIT_CACHE_DIR already in the environment is
# overwritten.
# NOTE(review): the DG_* variables are set after `import deep_gemm` above, so
# they only take effect if DeepGEMM reads them lazily -- confirm.
os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
)

# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
# NVRTC may have performance loss with some cases.
# And NVCC JIT speed is also 9x faster in the ref commit
# Hence DG_JIT_USE_NVRTC defaults to "0" unless SGL_DG_USE_NVRTC overrides it.
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
    """Configure DeepGEMM pre-compilation for the current rank.

    Rebuilds the module-level list of M values to pre-compile from
    ``server_args.chunked_prefill_size`` and decides whether this rank should
    compile all Ms.

    Args:
        gpu_id: GPU index used by the current rank.
        server_args: Server configuration; ``chunked_prefill_size`` and
            ``base_gpu_id`` are read.
    """
    global _BUILTIN_M_LIST
    global _DO_COMPILE_ALL
    global _IS_FIRST_RANK_ON_NODE

    # Generate m_max: 16K by default, 64K when chunked_prefill_size < 1, twice
    # the chunk size when it exceeds 8192; always capped at 128K.
    m_max = 1024 * 16
    if server_args.chunked_prefill_size < 1:
        m_max = 1024 * 64
    elif server_args.chunked_prefill_size > 8192:
        m_max = server_args.chunked_prefill_size * 2
    m_max = min(1024 * 128, m_max)
    _BUILTIN_M_LIST = list(range(1, m_max + 1))

    # Check if this is the first rank on the node. base_gpu_id must come from
    # the instance: reading it from the ServerArgs class ignores the value
    # configured for this server.
    _IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id

    # Only the first rank on the node compiles all Ms, so all symbols are
    # loaded during launch instead of while serving.
    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
class DeepGemmKernelType(IntEnum):
    """DeepGEMM kernel variants that this module can warm up / pre-compile."""

    # Explicit integer values, numbered from 1 in declaration order.
    GROUPED_GEMM_NT_F8F8BF16_MASKED = 1
    GROUPED_GEMM_NT_F8F8BF16_CONTIG = 2
    GEMM_NT_F8F8BF16 = 3


# Shapes already scheduled for all-M pre-compilation:
# (kernel_type, n, k, num_groups) -> True.
_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = {}
# TODO improve code
def _maybe_compile_deep_gemm_one_type_all(
    kernel_type: DeepGemmKernelType,
    n: int,
    k: int,
    num_groups: int,
) -> None:
    """Pre-compile one kernel shape for every M in the built-in list, once.

    Nothing happens unless pre-compilation is enabled, this rank is configured
    to compile all Ms, and ``(kernel_type, n, k, num_groups)`` has not been
    seen before. The shape is marked as seen *before* compiling starts, so it
    is not attempted again even if compilation raises.
    """
    cache_key = (kernel_type, n, k, num_groups)
    should_compile = (
        _ENABLE_JIT_DEEPGEMM_PRECOMPILE
        and _DO_COMPILE_ALL
        and _INITIALIZATION_DICT.get(cache_key) is None
    )
    if not should_compile:
        return

    # Mark up front so later calls with the same shape return immediately.
    _INITIALIZATION_DICT[cache_key] = True

    # TODO maybe improve logs
    if _IS_FIRST_RANK_ON_NODE and not _IN_PRECOMPILE_STAGE:
        logger.warning(
            "Entering DeepGEMM JIT Pre-Compile session. "
            "It may take a long time (typically 10-20 mins) "
            "if you have not run `sglang.compile_deep_gemm`. "
            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
            " for pre-compilation to reduce the overhead if you have not run it before. "
            "For example: "
            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
        )

    # Outside the pre-compile stage, remind the user that a prior
    # compile_deep_gemm run makes this step fast.
    precompile_hint = (
        ""
        if _IN_PRECOMPILE_STAGE
        else " It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. "
    )
    logger.info(
        f"Try DeepGEMM JIT Compiling for <{kernel_type.name}> "
        f"N={n}, K={k}, num_groups={num_groups} with all Ms.{precompile_hint}"
    )

    _compile_deep_gemm_one_type_all(
        kernel_type=kernel_type,
        n=n,
        k=k,
        num_groups=num_groups,
        m_list=_BUILTIN_M_LIST,
    )
# NOTE(alcanderian): get_num_sms should be changed when two-batch overlap is introduced
def _compile_deep_gemm_one_type_all(
    kernel_type: DeepGemmKernelType,
    n: int,
    k: int,
    num_groups: int,
    m_list: List[int],
) -> None:
    """Warm up ``kernel_type`` for shape ``(n, k, num_groups)`` at every M.

    Buffers are allocated once for ``max(m_list)`` and every M in ``m_list``
    is executed with DeepGEMM's compile mode set to 1. The previous compile
    mode is restored afterwards, even if a warmup call raises. Afterwards the
    current CUDA stream is synchronized and the cached memory is released.

    Args:
        kernel_type: Which DeepGEMM kernel family to warm up.
        n, k: GEMM dimensions.
        num_groups: Number of groups (ignored by the dense kernel executor).
        m_list: M values to run. For the contiguous grouped kernel, only
            values aligned to DeepGEMM's contiguous-layout alignment are kept.
    """
    if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
        # The contiguous layout only accepts aligned Ms; drop the rest and dedupe.
        m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
        m_list = sorted({m for m in m_list if m % m_alignment == 0})

    if not m_list:
        # Nothing to warm up; max() below would raise on an empty list.
        return

    executor = _BaseWarmupExecutor.create(
        kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
    )

    # NOTE(review): compile mode 1 is presumably DeepGEMM's "compile-only"
    # mode used for warmup -- confirm against the DeepGEMM version in use.
    old_compile_mode = deep_gemm.get_compile_mode()
    deep_gemm.set_compile_mode(1)
    try:
        # TODO can use multi thread
        for m in tqdm(m_list, desc="DeepGEMM warmup"):
            executor.execute(m=m)
    finally:
        # The compile mode is process-global state; always restore it.
        deep_gemm.set_compile_mode(old_compile_mode)

    # clean up input buffers
    torch.cuda.current_stream().synchronize()
    del executor
    torch.cuda.empty_cache()
class _BaseWarmupExecutor:
    """Base class for objects that run one DeepGEMM kernel type at a given M.

    The subclasses in this module allocate their buffers once in ``__init__``
    (sized by ``max_m``) and implement :meth:`execute`.
    """

    @staticmethod
    def create(kernel_type: DeepGemmKernelType, **kwargs):
        """Construct the warmup executor matching ``kernel_type``.

        ``kwargs`` are forwarded to the executor constructor (in this module:
        ``max_m``, ``n``, ``k``, ``num_groups``).

        Raises:
            KeyError: if ``kernel_type`` has no registered executor.
        """
        # Declared static so it also works when called through an instance;
        # without the decorator the instance would be bound to kernel_type.
        executor_cls = {
            DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
        }[kernel_type]
        return executor_cls(**kwargs)

    def execute(self, m):
        """Run the kernel once for ``m`` rows; subclasses must override."""
        raise NotImplementedError
def _empty_token_fp8(size):
    """Allocate uninitialized FP8 activations plus float32 scales on CUDA.

    ``size`` is ``(*leading_dims, k)``. Returns ``(values, scales)``, where
    ``values`` has shape ``size`` (float8_e4m3fn) and ``scales`` has shape
    ``(*leading_dims, ceil_div(k, _BLOCK_SIZE))`` (float32): one scale per
    ``_BLOCK_SIZE``-wide chunk of the last dimension.
    """
    *lead_dims, k_dim = size
    scale_shape = (*lead_dims, ceil_div(k_dim, _BLOCK_SIZE))
    values = torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn)
    scales = torch.empty(scale_shape, device="cuda", dtype=torch.float32)
    return values, scales
def _empty_block_fp8(size):
    """Allocate uninitialized FP8 weights plus per-tile float32 scales on CUDA.

    ``size`` is ``(*leading_dims, n, k)``. Returns ``(values, scales)``, where
    ``values`` has shape ``size`` (float8_e4m3fn) and ``scales`` has shape
    ``(*leading_dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE))``
    (float32): one scale per ``_BLOCK_SIZE`` x ``_BLOCK_SIZE`` tile.
    """
    *batch_dims, n_dim, k_dim = size
    scale_shape = (
        *batch_dims,
        ceil_div(n_dim, _BLOCK_SIZE),
        ceil_div(k_dim, _BLOCK_SIZE),
    )
    values = torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn)
    scales = torch.empty(scale_shape, device="cuda", dtype=torch.float32)
    return values, scales
# Granularity of the FP8 scale tensors built by _empty_token_fp8 and
# _empty_block_fp8: one float32 scale per _BLOCK_SIZE elements along k
# (and along n for weights).
_BLOCK_SIZE: int = 128
class _NormalWarmupExecutor(_BaseWarmupExecutor):
    """Warmup driver for the dense ``deep_gemm.fp8_gemm_nt`` kernel.

    Activation and output buffers hold ``max_m`` rows; ``execute`` runs on the
    first ``m`` of them. ``num_groups`` is accepted only so that all executors
    share one constructor signature; it is unused here.
    """

    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
        activations = _empty_token_fp8((max_m, k))
        weights = _empty_block_fp8((n, k))
        self.lhs_q, self.lhs_s = activations
        self.rhs_q, self.rhs_s = weights
        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)

    def execute(self, m):
        rows = slice(None, m)
        lhs = (self.lhs_q[rows], self.lhs_s[rows])
        rhs = (self.rhs_q, self.rhs_s)
        deep_gemm.fp8_gemm_nt(lhs, rhs, self.out[rows])
class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
    """Warmup driver for ``deep_gemm.m_grouped_fp8_gemm_nt_contiguous``.

    Activations, ``m_indices`` and the output are sized for ``max_m`` rows and
    sliced to ``m`` per call; the weights hold ``num_groups`` matrices.
    ``m_indices`` is zero-filled.
    """

    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
        token_q, token_s = _empty_token_fp8((max_m, k))
        weight_q, weight_s = _empty_block_fp8((num_groups, n, k))
        self.lhs_q, self.lhs_s = token_q, token_s
        self.rhs_q, self.rhs_s = weight_q, weight_s
        self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)

    def execute(self, m):
        rows = slice(None, m)
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            (self.lhs_q[rows], self.lhs_s[rows]),
            (self.rhs_q, self.rhs_s),
            self.out[rows],
            m_indices=self.m_indices[rows],
        )
class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
    """Warmup driver for ``deep_gemm.fp8_m_grouped_gemm_nt_masked``.

    Buffers are shaped ``(num_groups, max_m, ...)`` and always passed whole;
    the requested ``m`` reaches DeepGEMM only as ``expected_m``.
    ``masked_m`` is zero-filled.
    """

    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
        self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
        self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
        out_shape = (num_groups, max_m, n)
        self.out = torch.empty(out_shape, device="cuda", dtype=torch.bfloat16)

    def execute(self, m):
        lhs = (self.lhs_q, self.lhs_s)
        rhs = (self.rhs_q, self.rhs_s)
        # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`,
        # so (presumably) varying expected_m is what selects each config here.
        deep_gemm.fp8_m_grouped_gemm_nt_masked(
            lhs, rhs, self.out, masked_m=self.masked_m, expected_m=m
        )
@contextmanager
def deep_gemm_execution_hook(
    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
):
    """Context manager to wrap a DeepGEMM call of the given shape.

    On entry it triggers (at most once per shape, and only when enabled)
    pre-compilation of ``kernel_type`` for ``(n, k, num_groups)`` across all
    Ms, then runs the wrapped body. ``m`` is part of the interface but is not
    used here.
    """
    # @contextmanager is required: this body is a generator, and without the
    # decorator it could not be used in a `with` statement.
    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
    yield
Xet Storage Details
- Size:
- 8.07 kB
- Xet hash:
- 28cd72fdb124d122b1232c20cadaba7f8b3f0fe7b7de363d758fb3b670b60a5c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.