| import logging | |
| from contextlib import contextmanager | |
| from typing import Tuple | |
| import torch | |
| from sglang.srt.layers.deep_gemm_wrapper import compile_utils | |
| from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401 | |
| DEEPGEMM_BLACKWELL, | |
| DEEPGEMM_SCALE_UE8M0, | |
| ENABLE_JIT_DEEPGEMM, | |
| ) | |
| from sglang.srt.server_args import ServerArgs | |
| from sglang.srt.utils import get_bool_env_var | |
logger = logging.getLogger(__name__)

# deep_gemm is imported only when JIT DeepGEMM is enabled. The wrapper
# functions in this module reference the `deep_gemm` module name, so they are
# only usable when ENABLE_JIT_DEEPGEMM is true.
if ENABLE_JIT_DEEPGEMM:
    import deep_gemm
    # Not used in this module (hence noqa: F401); presumably re-exported for
    # callers that import it from here -- confirm before removing.
    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor  # noqa: F401

# Opt-in debug validation of FP8 operand scales, enabled via the
# SGLANG_DEEPGEMM_SANITY_CHECK environment variable (read by _sanity_check_input).
_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
# TODO: maybe rename these functions
def grouped_gemm_nt_f8f8bf16_masked(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
):
    """Dispatch DeepGEMM's masked, grouped FP8 GEMM.

    Args:
        lhs: ``(values, scales)`` pair. ``values`` must be 3-D and is read as
            ``(num_groups, _, k)``.
        rhs: ``(values, scales)`` pair. ``values`` must be 3-D; its middle
            dimension is taken as ``n``.
        out: Output tensor handed to DeepGEMM (presumably filled in place --
            confirm against the deep_gemm docs).
        masked_m: Forwarded unchanged to
            ``deep_gemm.fp8_m_grouped_gemm_nt_masked``; presumably the number
            of valid rows per group -- TODO confirm.
        expected_m: Forwarded to DeepGEMM, and also reported as ``m`` to
            ``compile_utils.deep_gemm_execution_hook``.
    """
    lhs_values = lhs[0]
    num_groups, _, k = lhs_values.shape
    rhs_values = rhs[0]
    _, n, _ = rhs_values.shape
    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED

    # Optional debug validation of both operands' scales (no-op unless enabled).
    for operand in (lhs, rhs):
        _sanity_check_input(operand)

    # expected_m stands in for m here; no other row count is given to the hook.
    execution_hook = compile_utils.deep_gemm_execution_hook(
        expected_m, n, k, num_groups, kernel_type
    )
    with execution_hook:
        deep_gemm.fp8_m_grouped_gemm_nt_masked(lhs, rhs, out, masked_m, expected_m)
def grouped_gemm_nt_f8f8bf16_contig(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
    m_indices: torch.Tensor,
):
    """Dispatch DeepGEMM's contiguous-layout, grouped FP8 GEMM.

    Args:
        lhs: ``(values, scales)`` pair. ``values`` must be 2-D and is read as
            ``(m, k)``.
        rhs: ``(values, scales)`` pair. ``values`` must be 3-D and is read as
            ``(num_groups, n, _)``.
        out: Output tensor handed to DeepGEMM (presumably filled in place --
            confirm against the deep_gemm docs).
        m_indices: Forwarded unchanged to
            ``deep_gemm.m_grouped_fp8_gemm_nt_contiguous``; presumably maps
            each lhs row to its group -- TODO confirm.
    """
    lhs_values = lhs[0]
    m, k = lhs_values.shape
    rhs_values = rhs[0]
    num_groups, n, _ = rhs_values.shape
    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

    # Optional debug validation of both operands' scales (no-op unless enabled).
    for operand in (lhs, rhs):
        _sanity_check_input(operand)

    execution_hook = compile_utils.deep_gemm_execution_hook(
        m, n, k, num_groups, kernel_type
    )
    with execution_hook:
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
def gemm_nt_f8f8bf16(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
):
    """Dispatch DeepGEMM's plain (ungrouped) FP8 GEMM.

    Args:
        lhs: ``(values, scales)`` pair. ``values`` must be 2-D and is read as
            ``(m, k)``.
        rhs: ``(values, scales)`` pair. ``values`` must be 2-D; its first
            dimension is taken as ``n``.
        out: Output tensor handed to ``deep_gemm.fp8_gemm_nt`` (presumably
            filled in place -- confirm against the deep_gemm docs).
    """
    lhs_values = lhs[0]
    m, k = lhs_values.shape
    rhs_values = rhs[0]
    n, _ = rhs_values.shape
    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16

    # Optional debug validation of both operands' scales (no-op unless enabled).
    for operand in (lhs, rhs):
        _sanity_check_input(operand)

    # A non-grouped GEMM is reported to the hook as a single group.
    execution_hook = compile_utils.deep_gemm_execution_hook(m, n, k, 1, kernel_type)
    with execution_hook:
        deep_gemm.fp8_gemm_nt(lhs, rhs, out)
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
    """Forward DeepGEMM configuration for ``gpu_id`` to ``compile_utils``.

    Thin public alias of ``compile_utils.update_deep_gemm_config``; both
    arguments are passed through positionally and nothing is returned.
    """
    apply_config = compile_utils.update_deep_gemm_config
    apply_config(gpu_id, server_args)
@contextmanager
def configure_deep_gemm_num_sms(num_sms):
    """Temporarily override the number of SMs DeepGEMM uses.

    Use as a context manager::

        with configure_deep_gemm_num_sms(n):
            ...

    When ``num_sms`` is ``None`` this is a no-op. Otherwise the current value
    from ``deep_gemm.get_num_sms()`` is saved, ``deep_gemm.set_num_sms`` is
    called with ``num_sms`` for the duration of the ``with`` block, and the
    saved value is restored on exit, including when the block raises.

    Args:
        num_sms: SM count to apply inside the block, or ``None`` to leave
            DeepGEMM's setting untouched.

    Note:
        A non-``None`` ``num_sms`` requires the module-level ``deep_gemm``
        import, which only happens when ENABLE_JIT_DEEPGEMM is true.
    """
    if num_sms is None:
        yield
        return

    original_num_sms = deep_gemm.get_num_sms()
    deep_gemm.set_num_sms(num_sms)
    try:
        yield
    finally:
        # Restore even if the body raised, so the override never leaks past
        # the with-block.
        deep_gemm.set_num_sms(original_num_sms)
def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
    """Debug-only check that an FP8 operand's scales survive ``ceil_to_ue8m0``.

    Does nothing unless the module-level ``_SANITY_CHECK`` flag is set. Scale
    tensors of dtype ``torch.int`` are skipped (presumably already in a packed
    integer format -- TODO confirm). For any other scale tensor, asserts that
    it is element-wise equal to ``ceil_to_ue8m0(x_scale)``, i.e. that rounding
    with that helper leaves it unchanged.

    Args:
        x_fp8: ``(values, scales)`` pair; only the scales are inspected.

    Raises:
        AssertionError: if any scale differs from its ``ceil_to_ue8m0`` value.
    """
    if _SANITY_CHECK:
        _, x_scale = x_fp8
        if x_scale.dtype != torch.int:
            # Imported lazily so the dependency is only loaded when checking.
            from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0

            x_scale_ceil = ceil_to_ue8m0(x_scale)
            # Local names must stay `x_scale` / `x_scale_ceil`: the f-string's
            # `=` specifier embeds them in the assertion message.
            assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
Xet Storage Details
- Size: 3.2 kB
- Xet hash: a0d42a108458f79f891a20bbb739a7acc15ef55414c1c407273743bcba2a3a85
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.