| |
| |
| |
| from __future__ import annotations |
|
|
| import logging |
| from enum import IntEnum |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional |
|
|
| import regex as re |
| import torch |
| from torch.nn.parameter import Parameter |
|
|
| from sglang.srt.distributed import get_tp_group |
| from sglang.srt.distributed.device_communicators.pynccl_allocator import ( |
| use_symmetric_memory, |
| ) |
| from sglang.srt.environ import envs |
| from sglang.srt.layers.dp_attention import is_allocation_symmetric |
| from sglang.srt.layers.moe import ( |
| MoeRunner, |
| MoeRunnerBackend, |
| MoeRunnerConfig, |
| get_moe_a2a_backend, |
| get_moe_runner_backend, |
| ) |
| from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType |
| from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo |
| from sglang.srt.layers.moe.utils import ( |
| is_flashinfer_cutedsl_v1_path, |
| should_use_flashinfer_cutlass_moe_fp4_allgather, |
| ) |
| from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter |
| from sglang.srt.layers.quantization.base_config import ( |
| FusedMoEMethodBase, |
| LinearMethodBase, |
| QuantizationConfig, |
| QuantizeMethodBase, |
| ) |
| from sglang.srt.layers.quantization.fp4_utils import ( |
| fp4_quantize, |
| get_fp4_gemm_runner_backend, |
| ) |
| from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant |
| from sglang.srt.layers.quantization.fp8_utils import ( |
| apply_fp8_linear, |
| cutlass_fp8_supported, |
| is_blackwell_supported, |
| ) |
| from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod |
| from sglang.srt.layers.quantization.marlin_utils_fp4 import ( |
| apply_fp4_marlin_linear, |
| prepare_moe_nvfp4_layer_for_marlin, |
| prepare_nvfp4_layer_for_marlin, |
| ) |
| from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod |
| from sglang.srt.layers.quantization.utils import ( |
| convert_to_channelwise, |
| is_layer_skipped, |
| per_tensor_dequantize, |
| requantize_with_max_scale, |
| swizzle_blockscale, |
| ) |
| from sglang.srt.layers.radix_attention import RadixAttention |
| from sglang.srt.layers.utils import alias_or_bind_derived_param, copy_or_rebind_param |
| from sglang.srt.utils.common import ( |
| get_device_capability, |
| is_cuda, |
| is_sm120_supported, |
| next_power_of_2, |
| round_up, |
| ) |
| from sglang.srt.utils.custom_op import register_custom_op |
| from sglang.srt.utils.patch_torch import register_fake_if_exists |
|
|
| if TYPE_CHECKING: |
| from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE |
| from sglang.srt.layers.moe.token_dispatcher import ( |
| CombineInput, |
| StandardDispatchOutput, |
| ) |
| from sglang.srt.models.utils import WeightsMapper |
|
|
# Optional FlashInfer FP4 GEMM support. `enable_flashinfer_fp4_gemm` records
# whether the import succeeded so later code can branch on it.
try:
    from flashinfer import mm_fp4 as flashinfer_fp4_gemm
    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a

    enable_flashinfer_fp4_gemm = True
except ImportError:
    # FlashInfer is not importable: disable the FlashInfer FP4 path and bind
    # the helper names to None.
    # NOTE(review): `flashinfer_fp4_gemm` stays unbound on this path, and
    # `shuffle_matrix_a` is bound only here (it is not imported in the `try`
    # branch) - confirm both are intended.
    enable_flashinfer_fp4_gemm = False
    reorder_rows_for_gated_act_gemm = None
    shuffle_matrix_a = None
    shuffle_matrix_sf_a = None
|
|
# The JIT CUTLASS FP4 GEMM kernel is only looked up when `is_cuda()` is true.
# In every other case (not CUDA, or the import fails) `cutlass_fp4_gemm` is
# None, so users of the name must check for None before calling it.
if is_cuda():
    try:
        from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm as cutlass_fp4_gemm
    except ImportError:
        cutlass_fp4_gemm = None
else:
    cutlass_fp4_gemm = None
|
|
# Optional FlashInfer CUTLASS fused-MoE kernel and its activation enum.
try:
    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
    from flashinfer.fused_moe.core import ActivationType
except ImportError:
    flashinfer_cutlass_fused_moe = None

    # Local stand-in so `ActivationType` is still defined when FlashInfer is
    # missing.
    # NOTE(review): the integer values presumably mirror FlashInfer's own
    # ActivationType enum - confirm against the installed FlashInfer version.
    class ActivationType(IntEnum):
        Swiglu = 3
        Geglu = 4
        Relu2 = 6
        Identity = 7
|
|
|
|
| |
| logger = logging.getLogger(__name__) |
|
|
|
|
| def _sglang_fp4_gemm_fake( |
| input: torch.Tensor, |
| weight: torch.Tensor, |
| input_sf: torch.Tensor, |
| weight_sf: torch.Tensor, |
| alpha: torch.Tensor, |
| out_dtype: torch.dtype, |
| out_features: int, |
| ) -> torch.Tensor: |
| M = input.shape[-2] |
| N = int(out_features) |
| return input.new_empty((M, N), dtype=out_dtype) |
|
|
|
|
@register_custom_op(fake_impl=_sglang_fp4_gemm_fake)
def fp4_gemm(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_sf: torch.Tensor,
    weight_sf: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    out_features: int,
) -> torch.Tensor:
    """Run an FP4 GEMM on whichever kernel backend is configured and available.

    Dispatch order:
      1. The CUTLASS kernel, when the configured backend is CUTLASS and the
         kernel was imported. Scale factors are reinterpreted as
         ``torch.float8_e4m3fn`` first if they have a different dtype.
      2. FlashInfer ``mm_fp4``, when FlashInfer was imported, using the
         FlashInfer backend name derived from the configured backend.
      3. The CUTLASS kernel as a last resort (scale factors passed as-is).

    Args:
        input: Quantized activation tensor.
        weight: Quantized weight tensor.
        input_sf: Scale factors for ``input``.
        weight_sf: Scale factors for ``weight``.
        alpha: Scaling tensor forwarded to the kernel.
        out_dtype: Dtype of the result.
        out_features: Output width; only used by the fake (shape-inference)
            implementation, not by the real kernels.

    Returns:
        The GEMM result produced by the selected kernel.

    Raises:
        RuntimeError: If neither the CUTLASS nor the FlashInfer FP4 kernel
            is available.
    """
    fp4_backend = get_fp4_gemm_runner_backend()

    if fp4_backend.is_cutlass() and cutlass_fp4_gemm is not None:
        # View the scale factors as float8_e4m3fn when they arrive in a
        # different dtype; `.view` reinterprets the bytes without copying.
        if input_sf.dtype != torch.float8_e4m3fn:
            input_sf = input_sf.view(torch.float8_e4m3fn)
        if weight_sf.dtype != torch.float8_e4m3fn:
            weight_sf = weight_sf.view(torch.float8_e4m3fn)
        return cutlass_fp4_gemm(input, weight, input_sf, weight_sf, alpha, out_dtype)

    if enable_flashinfer_fp4_gemm:
        backend = fp4_backend.get_flashinfer_backend()
        return flashinfer_fp4_gemm(
            input, weight, input_sf, weight_sf, alpha, out_dtype, backend=backend
        )

    # Last resort: CUTLASS. It may be None (non-CUDA build or failed import);
    # fail with an actionable message instead of "'NoneType' is not callable".
    if cutlass_fp4_gemm is None:
        raise RuntimeError(
            "No FP4 GEMM kernel is available: neither the sglang CUTLASS FP4 "
            "kernel nor FlashInfer mm_fp4 could be imported."
        )
    return cutlass_fp4_gemm(input, weight, input_sf, weight_sf, alpha, out_dtype)
|
|
|
|
# Only on CUDA, when SM120 is not supported and `fp4_quantize` is available:
# attach a fake implementation to the `sgl_kernel::scaled_fp4_quant` op.
# NOTE(review): presumably `register_fake_if_exists` is a no-op when the op is
# not registered - confirm in sglang.srt.utils.patch_torch.
if is_cuda() and (not is_sm120_supported()) and (fp4_quantize is not None):

    @register_fake_if_exists("sgl_kernel::scaled_fp4_quant")
    def _sgl_kernel_scaled_fp4_quant_fake(
        output, input, output_scale, input_global_scale
    ):
        # The fake does nothing and returns None; it creates no tensors.
        return
|
|
|
|
| |
# Default alignment (in elements) for both the N and K dimensions in
# `pad_nvfp4_weight`.
FP4_GEMM_ALIGNMENT = 32
|
|
|
|
def round_up_to_multiple(x: int, m: int) -> int:
    """Return the smallest multiple of ``m`` that is greater than or equal to ``x``.

    Uses integer arithmetic only. ``m`` is expected to be a positive integer;
    ``m == 0`` raises ``ZeroDivisionError``.
    """
    multiples_needed = (x + m - 1) // m
    return multiples_needed * m
|
|
|
|
def pad_nvfp4_weight(
    weight: torch.Tensor,
    n_alignment: int = FP4_GEMM_ALIGNMENT,
    k_alignment: int = FP4_GEMM_ALIGNMENT,
) -> tuple[torch.Tensor, int]:
    """Zero-pad a packed NVFP4 weight so N and K meet kernel alignment rules.

    Backend requirements differ:
    - CUTLASS/cuDNN: N % 32 == 0, K % 32 == 0
    - TRTLLM: N % 128 == 0 (for shuffle_matrix_sf_a), K padding handled separately

    Args:
        weight: Packed FP4 weight of shape [N, K//2] (two FP4 values per byte).
        n_alignment: Alignment for the N (row) dimension; values <= 0 skip
            row padding (default 32, use 128 for TRTLLM).
        k_alignment: Alignment for the K dimension, counted in FP4 elements;
            values <= 0 skip column padding.

    Returns:
        ``(padded_weight, weights_padding_cols)``. ``weights_padding_cols`` is
        the number of byte columns appended along K. When no padding is
        needed the input tensor is returned unchanged.
    """
    rows, col_bytes = weight.shape[0], weight.shape[1]

    # N dimension: rows are unpacked, so they are aligned directly.
    extra_rows = 0
    if n_alignment > 0 and rows % n_alignment != 0:
        extra_rows = round_up_to_multiple(rows, n_alignment) - rows

    # K dimension: alignment is expressed in FP4 elements, storage is in
    # bytes (two elements per byte), so convert before and after rounding.
    col_elements = col_bytes * 2
    extra_col_bytes = 0
    if k_alignment > 0 and col_elements % k_alignment != 0:
        extra_elements = round_up_to_multiple(col_elements, k_alignment) - col_elements
        extra_col_bytes = extra_elements // 2

    if not (extra_rows > 0 or extra_col_bytes > 0):
        return weight, extra_col_bytes

    # F.pad takes pads for the last dimension first: (left, right, top, bottom).
    padded = torch.nn.functional.pad(
        weight, (0, extra_col_bytes, 0, extra_rows)
    ).contiguous()
    return padded, extra_col_bytes
|
|
|
|
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_cols: int,
) -> torch.Tensor:
    """Append zero columns to packed FP4 activations to match padded weights.

    Args:
        x_fp4: Packed FP4 activation tensor.
        weights_padding_cols: Number of byte columns that were appended to the
            weight along K; the same number is appended to the last dimension
            of ``x_fp4``.

    Returns:
        ``x_fp4`` itself when no padding is requested, otherwise a new
        contiguous tensor with the extra zero columns.
    """
    if weights_padding_cols <= 0:
        return x_fp4
    padded = torch.nn.functional.pad(x_fp4, (0, weights_padding_cols))
    return padded.contiguous()
|
|
|
|
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """Drop padded columns from the last dimension of an FP4 GEMM result.

    Args:
        out: Output tensor from the FP4 GEMM.
        output_size: Size of the last dimension before the weight was padded.

    Returns:
        ``out`` itself when its last dimension already equals ``output_size``;
        otherwise a contiguous copy of the first ``output_size`` columns.
    """
    if out.shape[-1] == output_size:
        return out
    trimmed = out[..., :output_size]
    return trimmed.contiguous()
|
|
|
|
| |
# Value of the SGLANG_MOE_NVFP4_DISPATCH setting, read once at import time.
MOE_NVFP4_DISPATCH = envs.SGLANG_MOE_NVFP4_DISPATCH.get()

# NOTE(review): presumably the activation quantization schemes this module
# supports - not referenced in the visible part of the file; confirm usage.
ACTIVATION_SCHEMES = ["static"]


# Activation names accepted by the FlashInfer CUTLASS FP8 MoE path
# (asserted in ModelOptFp8MoEMethod.apply).
_SUPPORTED_ACT_STRS = ("silu", "relu2", "gelu")
|
|
|
|
class ModelOptQuantConfig(QuantizationConfig):
    """State and helpers shared by the ModelOpt quantization configs.

    Holds the KV-cache quantization algorithm, the list of module patterns
    excluded from quantization and the packed-module mapping, and implements
    the layer-to-quant-method dispatch used by subclasses.
    """

    def __init__(
        self,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: Optional[List[str]],
        packed_modules_mapping: Optional[Dict[str, List[str]]],
    ):
        """Store the shared settings; a missing exclude list becomes ``[]``."""
        super().__init__()
        self.packed_modules_mapping = packed_modules_mapping
        self.exclude_modules = exclude_modules if exclude_modules else []
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.use_per_token_activation = False

    def _get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
        *,
        Linear: type[LinearMethodBase],
        Moe: type[FusedMoEMethodBase],
    ) -> Optional[QuantizeMethodBase]:
        """Choose the quantization method for ``layer``.

        Args:
            layer: Module being set up.
            prefix: Dotted name of the module within the model.
            Linear: Method class instantiated for non-excluded linear layers.
            Moe: Method class instantiated for non-excluded fused-MoE layers.

        Returns:
            An unquantized method for skipped/excluded linear layers,
            ``Linear(self)`` for other linear layers, the FP8 KV-cache method
            for attention layers when a KV-cache algorithm is configured,
            ``Moe(self)`` for non-excluded fused-MoE layers, else ``None``.
        """
        # Imported here rather than at module level (these are absent from
        # the file's top-level imports).
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            skip_quant = is_layer_skipped(
                prefix, self.exclude_modules, self.packed_modules_mapping
            ) or self.is_layer_excluded(prefix)
            return UnquantizedLinearMethod() if skip_quant else Linear(self)

        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self)

        # Excluded MoE layers get no quant method at all (None), unlike
        # excluded linear layers which get an explicit unquantized method.
        if isinstance(layer, FusedMoE) and not self.is_layer_excluded(prefix):
            return Moe(self)
        return None

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names that may hold this quantization config."""
        return ["hf_quant_config.json"]

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for ModelOpt)."""
        return []

    def apply_weight_name_mapper(
        self, hf_to_sglang_mapper: WeightsMapper
    ):
        """Translate ``exclude_modules`` through the given name mapper.

        Every mapped name is kept; names starting with ``language_model.``
        are additionally recorded without that prefix. Duplicates are removed
        while preserving first-seen order.
        """
        if not self.exclude_modules:
            return

        expanded: List[str] = []
        for mapped_name in hf_to_sglang_mapper.apply_list(self.exclude_modules):
            expanded.append(mapped_name)
            if mapped_name.startswith("language_model."):
                expanded.append(mapped_name.removeprefix("language_model."))

        # dict.fromkeys de-duplicates while keeping insertion order.
        self.exclude_modules = list(dict.fromkeys(expanded))

    def is_layer_excluded(self, prefix: str) -> bool:
        """Check if a layer should be excluded from quantization.

        A layer is excluded when any entry of ``exclude_modules`` matches by
        one of these rules:
        - The pattern fully matches the prefix (e.g. "lm_head"), with "*"
          acting as a wildcard (e.g. "mtp*" matches "mtp_layers").
        - The pattern fully matches any single dot-separated part of the
          prefix.
        - Either of the above applied to the prefix with a leading
          "language_model." removed.
        - The pattern's last component is one of the known fused-projection
          names and occurs inside the prefix's last component
          (e.g. "q_a_proj" in "fused_qkv_a_proj_with_mqa").
        """
        if not self.exclude_modules:
            return False

        candidates = [prefix]
        if prefix.startswith("language_model."):
            candidates.append(prefix.removeprefix("language_model."))

        fused_patterns = {"q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"}

        for pattern in self.exclude_modules:
            # Turn the glob-like pattern into a regex: literal dots, "*" -> ".*".
            matcher = re.compile(pattern.replace(".", r"\.").replace("*", r".*"))

            for candidate in candidates:
                if matcher.fullmatch(candidate):
                    return True
                if any(matcher.fullmatch(part) for part in candidate.split(".")):
                    return True

            # Substring match on the final component for fused projections.
            tail = pattern.rsplit(".", maxsplit=1)[-1]
            if tail in fused_patterns and any(
                tail in candidate.rsplit(".", maxsplit=1)[-1]
                for candidate in candidates
            ):
                return True

        return False
|
|
|
|
class ModelOptFp8Config(ModelOptQuantConfig):
    """Quantization config for ModelOpt FP8 checkpoints."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: Optional[str] = None,
        exclude_modules: Optional[List[str]] = None,
        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized (bool): Whether the checkpoint stores
                weights in serialized FP8 format. A warning is logged when True.
            kv_cache_quant_method: KV-cache quantization algorithm, if any.
            exclude_modules: Module name patterns excluded from quantization.
            packed_modules_mapping: Mapping from fused module name to the
                names of the shards it packs.
        """
        super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if not is_checkpoint_fp8_serialized:
            return
        logger.warning(
            "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_config, user_quant):
        """Delegate quantization-method override to the shared ModelOpt helper."""
        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "modelopt_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability as an integer (89)."""
        return 89

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
        """Build the config from either supported layout.

        Flat layout: ``quant_algo`` is a top-level key. The KV-cache method is
        "FP8" when ``kv_cache_scheme`` is a dict with type "float" and
        8 bits, and the exclude list comes from ``ignore``.

        Nested layout: everything is read from the ``quantization`` section
        (``quant_algo``, ``kv_cache_quant_algo``, ``exclude_modules``).

        Raises:
            ValueError: If ``quant_algo`` cannot be found, or it does not
                contain "FP8".
        """
        kv_cache_quant_method = None
        exclude_modules = None

        quant_method = config.get("quant_algo")
        if quant_method is None:
            # Nested layout (hf_quant_config.json style).
            try:
                section = cls.get_from_keys(config, ["quantization"])
                quant_method = section.get("quant_algo")
                kv_cache_quant_method = section.get("kv_cache_quant_algo")
                exclude_modules = section.get("exclude_modules")
            except ValueError:
                raise ValueError(
                    "Cannot find 'quant_algo' in the model's quantization config. "
                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
                )
        else:
            # Flat layout (config.json style).
            scheme = config.get("kv_cache_scheme")
            has_fp8_kv_cache = (
                isinstance(scheme, dict)
                and scheme.get("type") == "float"
                and scheme.get("num_bits") == 8
            )
            if has_fp8_kv_cache:
                kv_cache_quant_method = "FP8"
            exclude_modules = config.get("ignore")

        if quant_method is None:
            raise ValueError(
                "Cannot find 'quant_algo' in the model's quantization config. "
            )
        if "FP8" not in quant_method:
            raise ValueError(
                "ModelOptFp8Config only supports static FP8 quantization in SGLang. "
                "For FP4 quantization, use ModelOptFp4Config. "
                "Check the quantization config for your model's configuration."
            )

        return cls(
            is_checkpoint_fp8_serialized=True,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            packed_modules_mapping=config.get("packed_modules_mapping"),
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Resolve the quant method for ``layer`` using the FP8 method classes."""
        return self._get_quant_method(
            layer,
            prefix,
            Linear=ModelOptFp8LinearMethod,
            Moe=ModelOptFp8MoEMethod,
        )
|
|
|
|
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear-layer method for ModelOpt static FP8 quantization.

    Loads FP8 checkpoints that ship static weight and activation scales.
    Dynamic scales may be supported in the future.

    **Limitations**:
    1. Only supports per-tensor quantization due to `torch._scaled_mm` limitations.
    2. Only supports the `float8_e4m3fn` data type.

    Args:
        quant_config (ModelOptFp8Config): The ModelOpt quantization configuration.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        """Keep the config and record whether CUTLASS FP8 kernels are usable."""
        super().__init__()
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: Optional[int],
        output_size: Optional[int],
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the weight and, for FP8 checkpoints, its scale parameters.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)`` and is stored as ``float8_e4m3fn`` when the
        checkpoint is FP8-serialized, otherwise as ``params_dtype``. For
        serialized checkpoints, ``weight_scale`` and ``input_scale`` are
        registered with one float32 entry per logical output partition,
        initialized to the smallest float32 value as a placeholder.
        """
        out_per_partition = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")
        serialized = self.quant_config.is_checkpoint_fp8_serialized

        # Bookkeeping read back later (e.g. when requantizing after loading).
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_per_partition

        weight_storage = torch.empty(
            out_per_partition,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn if serialized else params_dtype,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not serialized:
            return

        num_scales = len(output_partition_sizes)
        for scale_name in ("weight_scale", "input_scale"):
            placeholder = torch.full(
                (num_scales,),
                torch.finfo(torch.float32).min,
                dtype=torch.float32,
            )
            layer.register_parameter(
                scale_name,
                PerTensorScaleParameter(data=placeholder, weight_loader=loader),
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Requantize the loaded weight with a single max scale.

        The weight is replaced by the transposed requantized weight, the
        weight scale by the max scale (converted to channel-wise form when
        CUTLASS FP8 is supported), and the input scale by its maximum.
        """
        shared_scale, requantized = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths
        )
        layer.weight = Parameter(requantized.t(), requires_grad=False)

        if self.cutlass_fp8_supported:
            shared_scale = convert_to_channelwise(shared_scale, layer.logical_widths)
        layer.weight_scale = Parameter(shared_scale, requires_grad=False)

        collapsed_input_scale = layer.input_scale.max()
        layer.input_scale = Parameter(collapsed_input_scale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op on ``x`` with the layer's weight and scales."""
        output = apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )
        return output
|
|
|
|
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method for ModelOpt-quantized checkpoints.

    Loading of the FP8 KV-cache scaling factors is inherited from
    ``BaseKVCacheMethod``; this subclass adds no behavior of its own.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        """Forward ``quant_config`` to the base KV-cache method."""
        super().__init__(quant_config)
|
|
|
|
class ModelOptMixedPrecisionConfig(ModelOptQuantConfig):
    """Configuration for ModelOpt MIXED_PRECISION checkpoints.

    Each layer's algorithm is looked up in ``quantized_layers``; FP8 layers are
    handled with ``fp8_config`` and NVFP4 layers with ``nvfp4_config``.
    """

    def __init__(
        self,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: Optional[List[str]],
        packed_modules_mapping: Optional[Dict[str, List[str]]],
        quantized_layers: Dict[str, Dict[str, Any]],
        fp8_config: ModelOptFp8Config,
        nvfp4_config: ModelOptFp4Config,
    ) -> None:
        """Store the per-layer algorithm map and the two delegate configs.

        Args:
            kv_cache_quant_algo: KV-cache quantization algorithm, if any.
            exclude_modules: Module name patterns excluded from quantization.
            packed_modules_mapping: Fused module name -> packed shard names.
            quantized_layers: Layer name -> info dict; each info dict is read
                for its ``"quant_algo"`` entry.
            fp8_config: Config handed to the FP8 linear/MoE/KV-cache methods.
            nvfp4_config: Config handed to the NVFP4 linear/MoE methods.
        """
        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
        self.quantized_layers = quantized_layers
        self.fp8_config = fp8_config
        self.nvfp4_config = nvfp4_config

    @classmethod
    def override_quantization_method(cls, hf_quant_config, user_quant):
        """Return "modelopt_mixed" when the HF config declares that method, else None."""
        if hf_quant_config is None:
            return None
        if hf_quant_config.get("quant_method", "") == "modelopt_mixed":
            return "modelopt_mixed"
        return None

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "modelopt_mixed"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the same minimum capability as the FP4 config."""
        return ModelOptFp4Config.get_min_capability()

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ModelOptMixedPrecisionConfig:
        """Build the config from a flat or nested quantization dict.

        Raises:
            ValueError: If ``quant_algo`` is not "MIXED_PRECISION" or the
                ``quantized_layers`` map is empty.
        """
        kv_cache_quant_algo = None
        exclude_modules = None
        quantized_layers = {}

        quant_algo = config.get("quant_algo")
        if quant_algo is not None:
            # Flat layout: keys live at the top level of `config`.
            kv_cache_scheme = config.get("kv_cache_scheme")
            if isinstance(kv_cache_scheme, dict):
                # NOTE(review): the source's indentation was lost; the final
                # `else` is assumed to belong to this inner if/elif chain (so
                # a missing kv_cache_scheme leaves the algo as None) - confirm.
                if (
                    kv_cache_scheme.get("type") == "float"
                    and kv_cache_scheme.get("num_bits") == 8
                ):
                    kv_cache_quant_algo = "FP8"
                elif (
                    kv_cache_scheme.get("type") == "float"
                    and kv_cache_scheme.get("num_bits") == 4
                ):
                    kv_cache_quant_algo = "NVFP4"
                else:
                    kv_cache_quant_algo = "auto"
            exclude_modules = config.get("ignore")
            quantized_layers = config.get("quantized_layers", {})
        else:
            # Nested layout: everything lives under the "quantization" key.
            quantization_section = cls.get_from_keys(config, ["quantization"])
            quant_algo = quantization_section.get("quant_algo")
            kv_cache_quant_algo = quantization_section.get("kv_cache_quant_algo")
            exclude_modules = quantization_section.get("exclude_modules")
            quantized_layers = quantization_section.get("quantized_layers", {})

        if quant_algo != "MIXED_PRECISION":
            raise ValueError(
                "ModelOptMixedPrecisionConfig only supports MIXED_PRECISION checkpoints."
            )
        if not quantized_layers:
            raise ValueError(
                "MIXED_PRECISION quantization requires a non-empty quantized_layers map."
            )

        # Use the group size of the first NVFP4 layer found; default to 16.
        group_size = None
        for layer_info in quantized_layers.values():
            if layer_info.get("quant_algo", "").upper() == "NVFP4":
                group_size = layer_info.get("group_size", 16)
                break
        if group_size is None:
            group_size = 16

        # The delegate configs get an empty exclude list; exclusion is decided
        # by this config in get_quant_method before delegating.
        packed_modules_mapping = config.get("packed_modules_mapping")
        fp8_config = ModelOptFp8Config(
            is_checkpoint_fp8_serialized=True,
            kv_cache_quant_method=kv_cache_quant_algo,
            exclude_modules=[],
            packed_modules_mapping=packed_modules_mapping,
        )
        nvfp4_config = ModelOptFp4Config(
            is_checkpoint_nvfp4_serialized=True,
            kv_cache_quant_algo=kv_cache_quant_algo,
            exclude_modules=[],
            packed_modules_mapping=packed_modules_mapping,
            group_size=group_size,
        )

        return cls(
            kv_cache_quant_algo=kv_cache_quant_algo,
            exclude_modules=exclude_modules,
            packed_modules_mapping=packed_modules_mapping,
            quantized_layers=quantized_layers,
            fp8_config=fp8_config,
            nvfp4_config=nvfp4_config,
        )

    def apply_weight_name_mapper(self, hf_to_sglang_mapper: WeightsMapper):
        """Map the exclude list and the keys of ``quantized_layers`` to sglang names."""
        super().apply_weight_name_mapper(hf_to_sglang_mapper)
        if self.quantized_layers:
            self.quantized_layers = hf_to_sglang_mapper.apply_dict(
                self.quantized_layers
            )

    def _resolve_quant_algo(self, prefix: str) -> Optional[str]:
        """Return the upper-cased quant algorithm for ``prefix``, or None.

        Lookup order:
          1. Exact key in ``quantized_layers``.
          2. If the last component of ``prefix`` is a packed (fused) module,
             the single algorithm shared by its shards that are listed.
          3. The first listed layer whose name starts with ``prefix + "."``.

        Raises:
            ValueError: If the shards of a fused layer use different algorithms.
        """
        if prefix in self.quantized_layers:
            return self.quantized_layers[prefix]["quant_algo"].upper()

        proj_name = prefix.rsplit(".", 1)[-1]
        if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
            algos = set()
            base = prefix.rsplit(".", 1)[0]
            for shard_name in self.packed_modules_mapping[proj_name]:
                shard_prefix = f"{base}.{shard_name}"
                if shard_prefix in self.quantized_layers:
                    algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
            if len(algos) == 1:
                return algos.pop()
            if len(algos) > 1:
                raise ValueError(
                    f"Mixed quant_algo within fused layer {prefix}: {algos}. "
                    "All shards must use the same quantization."
                )
            # No shard listed: fall through to the child-prefix lookup below.

        # Container modules: take the algorithm of the first listed child.
        prefix_dot = prefix + "."
        for key, info in self.quantized_layers.items():
            if key.startswith(prefix_dot):
                return info["quant_algo"].upper()

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Pick the quant method for ``layer`` from its resolved algorithm.

        Linear layers that are skipped/excluded, or whose algorithm is neither
        FP8 nor NVFP4, get the unquantized method. Attention layers get the
        FP8 KV-cache method when a KV-cache algorithm is configured. Fused-MoE
        layers get the FP8 or NVFP4 MoE method, or None when excluded or
        unresolved.
        """
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        # Resolved before the exclusion checks, so a fused layer with
        # conflicting shard algorithms raises even if it is excluded.
        quant_algo = self._resolve_quant_algo(prefix)

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix, self.exclude_modules, self.packed_modules_mapping
            ) or self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            if quant_algo == "FP8":
                return ModelOptFp8LinearMethod(self.fp8_config)
            if quant_algo == "NVFP4":
                return ModelOptFp4LinearMethod(self.nvfp4_config)
            return UnquantizedLinearMethod()

        # NOTE(review): the FP8 KV-cache method is used for any truthy
        # kv_cache_quant_algo, including "NVFP4" and "auto" - confirm intended.
        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self.fp8_config)

        if isinstance(layer, FusedMoE):
            if self.is_layer_excluded(prefix):
                return None
            if quant_algo == "FP8":
                return ModelOptFp8MoEMethod(self.fp8_config)
            if quant_algo == "NVFP4":
                return ModelOptNvFp4FusedMoEMethod(self.nvfp4_config)
            return None

        return None
|
|
|
|
| class ModelOptFp8MoEMethod(FusedMoEMethodBase): |
| """MoE method for ModelOpt FP8. |
| Supports loading FP8 checkpoints with static weight scale and activation scale. |
| |
| Args: |
| quant_config: The ModelOpt quantization config. |
| """ |
|
|
| def __init__(self, quant_config: ModelOptFp8Config): |
| self.quant_config = quant_config |
| self.cutlass_fp8_supported = cutlass_fp8_supported() |
|
|
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights and their per-tensor scales on ``layer``.

        Registers ``w13_weight`` of shape ``(num_experts, num_shards *
        intermediate_size_per_partition, hidden_size)`` and ``w2_weight`` of
        shape ``(num_experts, hidden_size, intermediate_size_per_partition)``,
        where ``num_shards`` is 2 for gated MoE and 1 otherwise. For
        FP8-serialized checkpoints the weights are ``float8_e4m3fn`` and the
        weight/input scale parameters are registered as well.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # FP8 storage for serialized checkpoints, otherwise the model dtype.
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")
        # Gated MoE stores two projections in w13, non-gated stores one.
        num_shards = 2 if layer.moe_runner_config.is_gated else 1
        intermediate_size = num_shards * intermediate_size_per_partition
        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                intermediate_size,
                hidden_size,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        # NOTE(review): the source's indentation was lost; everything below is
        # assumed to sit inside this `if` (weight scales, the attrs update and
        # the input scales) - confirm against the upstream file.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # Weight scales: one per (expert, shard) for w13 and one per
            # expert for w2, pre-filled with the smallest float32 value.
            w13_scale_shape = (num_experts, num_shards)
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    w13_scale_shape,
                    torch.finfo(torch.float32).min,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
                ),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Only mutates the local kwargs dict; nothing below reads it back.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # Input (activation) scales: one per expert, initialized to 1.0.
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)
|
|
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.

        Steps:
          1. Re-wrap the expert weights as plain, non-trainable Parameters.
          2. Collapse 2-D w13 weight scales to one scale per expert (the max
             across shards) and requantize each shard with that scale.
          3. Re-wrap the w2 weight scale; reduce each input scale to its max.
          4. Apply backend-specific preparation for the FlashInfer TRT-LLM or
             FlashInfer CUTLASS runner.
        """

        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        # The scale parameters exist only if create_weights registered them.
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # 2-D scales hold one value per (expert, shard); the result stored
            # back on the layer is a single scale per expert.
            if layer.w13_weight_scale.dim() == 2:
                # Largest shard scale of each expert.
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Each shard is dequantized with its own scale and requantized
                # in place with the expert's shared max scale.
                num_shards = 2 if layer.moe_runner_config.is_gated else 1
                intermediate_size_per_partition = (
                    layer.w13_weight.shape[1] // num_shards
                )
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(num_shards):
                        # Dequantize this shard's rows with its original scale.
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size_per_partition, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize with the shared scale and write the rows
                        # back; the returned scale is discarded.
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size_per_partition, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                        start += intermediate_size_per_partition

                # From here on the w13 scale is one value per expert.
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        # Input scales are reduced to a single scalar: the max over all experts.
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        # Backend-specific preparation.
        if get_moe_runner_backend().is_flashinfer_trtllm():
            from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
                align_fp8_moe_weights_for_flashinfer_trtllm,
            )

            # NOTE(review): presumably this helper also sets the
            # output*_scales_* attributes that apply() reads - confirm.
            align_fp8_moe_weights_for_flashinfer_trtllm(layer, swap_w13_halves=True)
        elif get_moe_runner_backend().is_flashinfer_cutlass():
            # This path needs all four scales to be present.
            assert (
                hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None
            )
            assert hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None
            assert (
                hasattr(layer, "w13_weight_scale")
                and layer.w13_weight_scale is not None
            )
            assert (
                hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None
            )

            input_scale = layer.w13_input_scale.to(torch.float32)
            activation_scale = layer.w2_input_scale.to(torch.float32)
            w13_weight_scale = layer.w13_weight_scale.to(torch.float32)
            w2_weight_scale = layer.w2_weight_scale.to(torch.float32)

            # Precomputed scale products stored on the layer:
            #   fc1_dequant       = w13 weight scale * w13 input scale
            #   fc2_quant         = 1 / w2 input scale
            #   fc2_dequant       = w2 input scale * w2 weight scale
            #   fc1_input_dequant = w13 input scale
            layer.fc1_dequant = Parameter(
                w13_weight_scale * input_scale, requires_grad=False
            )
            layer.fc2_quant = Parameter(
                activation_scale.reciprocal(), requires_grad=False
            )
            layer.fc2_dequant = Parameter(
                activation_scale * w2_weight_scale, requires_grad=False
            )
            layer.fc1_input_dequant = Parameter(input_scale, requires_grad=False)

            # Zero-pad the per-shard intermediate size up to a multiple of 16.
            # NOTE(review): presumably a kernel alignment requirement of the
            # FlashInfer CUTLASS MoE - confirm.
            num_shards = 2 if layer.moe_runner_config.is_gated else 1
            isp = layer.w13_weight.shape[1] // num_shards
            if isp % 16 != 0:
                pad_amount = round_up(isp, 16) - isp
                w13_data = layer.w13_weight.data
                if num_shards == 2:
                    # Pad each half separately so the two halves stay
                    # contiguous after concatenation.
                    up_weight = w13_data[:, :isp, :]
                    gate_weight = w13_data[:, isp:, :]
                    layer.w13_weight = Parameter(
                        torch.cat(
                            [
                                torch.nn.functional.pad(
                                    up_weight, (0, 0, 0, pad_amount)
                                ),
                                torch.nn.functional.pad(
                                    gate_weight, (0, 0, 0, pad_amount)
                                ),
                            ],
                            dim=1,
                        ),
                        requires_grad=False,
                    )
                else:
                    layer.w13_weight = Parameter(
                        torch.nn.functional.pad(w13_data, (0, 0, 0, pad_amount)),
                        requires_grad=False,
                    )
                # w2 consumes the intermediate dimension as its last axis, so
                # it is padded along that axis by the same amount.
                layer.w2_weight = Parameter(
                    torch.nn.functional.pad(layer.w2_weight.data, (0, pad_amount)),
                    requires_grad=False,
                )
|
|
| def create_moe_runner( |
| self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig |
| ): |
| self.moe_runner_config = moe_runner_config |
| self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) |
|
|
| def apply( |
| self, |
| layer: torch.nn.Module, |
| dispatch_output: StandardDispatchOutput, |
| ) -> CombineInput: |
| x = dispatch_output.hidden_states |
| topk_output = dispatch_output.topk_output |
| from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput |
| from sglang.srt.layers.moe.topk import TopKOutputChecker |
|
|
| |
|
|
| if ( |
| get_moe_runner_backend().is_flashinfer_trtllm() |
| and TopKOutputChecker.format_is_bypassed(topk_output) |
| ): |
| from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( |
| FlashInferTrtllmFp8MoeQuantInfo, |
| fused_experts_none_to_flashinfer_trtllm_fp8, |
| ) |
| from sglang.srt.layers.moe.utils import RoutingMethodType |
|
|
| topk_config = topk_output.topk_config |
|
|
| from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( |
| get_activation_type, |
| ) |
|
|
| _SUPPORTED_FP8_ACTIVATIONS = {"silu", "relu2"} |
| assert self.moe_runner_config.activation in _SUPPORTED_FP8_ACTIVATIONS, ( |
| f"Only {_SUPPORTED_FP8_ACTIVATIONS} are supported for " |
| f"flashinfer trtllm fp8 moe, got '{self.moe_runner_config.activation}'" |
| ) |
|
|
| routing_method_type = getattr( |
| layer, "routing_method_type", RoutingMethodType.Llama4 |
| ) |
|
|
| quant_info = FlashInferTrtllmFp8MoeQuantInfo( |
| w13_weight=layer.w13_weight, |
| w2_weight=layer.w2_weight, |
| global_num_experts=layer.num_experts, |
| local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, |
| local_num_experts=layer.num_local_experts, |
| intermediate_size=layer.w2_weight.shape[2], |
| routing_method_type=routing_method_type, |
| block_quant=False, |
| w13_input_scale=layer.w13_input_scale, |
| output1_scales_scalar=layer.output1_scales_scalar, |
| output1_scales_gate_scalar=layer.output1_scales_gate_scalar, |
| output2_scales_scalar=layer.output2_scales_scalar, |
| use_routing_scales_on_input=True, |
| activation_type=get_activation_type( |
| self.moe_runner_config.activation, |
| is_gated=self.moe_runner_config.is_gated, |
| ), |
| ) |
|
|
| return fused_experts_none_to_flashinfer_trtllm_fp8( |
| dispatch_output, quant_info, self.moe_runner_config |
| ) |
|
|
| if get_moe_runner_backend().is_flashinfer_cutlass(): |
| from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( |
| get_activation_type, |
| ) |
|
|
| activation_str = self.moe_runner_config.activation |
| assert activation_str in _SUPPORTED_ACT_STRS, ( |
| f"Activation {activation_str!r} is not supported for " |
| f"flashinfer cutlass fp8 moe (supported: {_SUPPORTED_ACT_STRS})." |
| ) |
| activation = ActivationType( |
| get_activation_type( |
| activation_str, is_gated=self.moe_runner_config.is_gated |
| ) |
| ) |
| |
| |
| _CUTLASS_SUPPORTED = { |
| ActivationType.Swiglu, |
| ActivationType.Geglu, |
| ActivationType.Relu2, |
| ActivationType.Identity, |
| } |
| assert activation in _CUTLASS_SUPPORTED, ( |
| f"Activation {activation_str!r} (is_gated=" |
| f"{self.moe_runner_config.is_gated}) maps to {activation.name}, " |
| "which is not supported by flashinfer cutlass fp8 moe." |
| ) |
| topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids |
| x_fp8, _ = scaled_fp8_quant(x, layer.w13_input_scale) |
| output_dtype = x.dtype |
| original_col = x.shape[1] |
| x_sf = None |
|
|
| with use_symmetric_memory( |
| get_tp_group(), disabled=not is_allocation_symmetric() |
| ): |
| symm_output = torch.empty( |
| x.shape[0], original_col, dtype=output_dtype, device=x.device |
| ) |
| output = flashinfer_cutlass_fused_moe( |
| output=symm_output, |
| input=x_fp8, |
| token_selected_experts=topk_ids.to(torch.int), |
| token_final_scales=topk_weights, |
| fc1_expert_weights=layer.w13_weight, |
| fc2_expert_weights=layer.w2_weight, |
| output_dtype=output_dtype, |
| input_sf=x_sf, |
| quant_scales=[ |
| layer.fc1_dequant, |
| layer.fc2_quant, |
| layer.fc2_dequant, |
| layer.fc1_input_dequant, |
| ], |
| ep_size=layer.moe_ep_size, |
| ep_rank=layer.moe_ep_rank, |
| tp_size=layer.moe_tp_size, |
| tp_rank=layer.moe_tp_rank, |
| tune_max_num_tokens=next_power_of_2(x.shape[0]), |
| activation_type=activation, |
| )[0] |
|
|
| if ( |
| not layer.should_fuse_routed_scaling_factor_in_topk |
| and self.moe_runner_config.routed_scaling_factor is not None |
| ): |
| output.mul_(self.moe_runner_config.routed_scaling_factor) |
|
|
| return StandardCombineInput(hidden_states=output) |
|
|
| quant_info = TritonMoeQuantInfo( |
| w13_weight=layer.w13_weight, |
| w2_weight=layer.w2_weight, |
| use_fp8_w8a8=True, |
| per_channel_quant=False, |
| w13_scale=layer.w13_weight_scale, |
| w2_scale=layer.w2_weight_scale, |
| a13_scale=layer.w13_input_scale, |
| a2_scale=layer.w2_input_scale, |
| ) |
|
|
| return self.runner.run(dispatch_output, quant_info) |
|
|
|
|
class ModelOptFp4Config(ModelOptQuantConfig):
    """Quantization config for ModelOpt FP4 (NVFP4) checkpoints.

    On top of what the base class stores, this keeps the quantization
    ``group_size``, whether the checkpoint is NVFP4-serialized, and whether
    per-token activation scaling is enabled.
    """

    def __init__(
        self,
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[List[str]] = None,
        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
        use_per_token_activation: Optional[bool] = None,
    ) -> None:
        """Initialise the config.

        Args:
            is_checkpoint_nvfp4_serialized: True when the checkpoint already
                stores NVFP4 weights; a warning is logged in that case.
            kv_cache_quant_algo: Forwarded to the base class.
            group_size: Quantization group size.
            exclude_modules: Forwarded to the base class.
            packed_modules_mapping: Forwarded to the base class.
            use_per_token_activation: Explicit switch; when falsy, the
                ``SGLANG_FLASHINFER_NVFP4_PER_TOKEN_ACTIVATION`` environment
                setting decides.
        """
        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        if is_checkpoint_nvfp4_serialized:
            logger.warning(
                "Detected nvfp4 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        self.group_size = group_size
        self.use_per_token_activation = (
            use_per_token_activation
            or envs.SGLANG_FLASHINFER_NVFP4_PER_TOKEN_ACTIVATION.get()
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_config, user_quant):
        """Override quantization method based on the model's config."""
        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "modelopt_fp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value (80) required by this method."""
        return 80

    @staticmethod
    def common_group_size(cfg: dict) -> int:
        """Return the unique group_size across the config; raise if missing/mismatched.

        Integer ``group_size`` entries are collected from the top level, from
        the ``quantization`` dict, from every ``config_groups`` entry, and from
        every dict nested directly inside such an entry.

        Raises:
            ValueError: If no integer ``group_size`` is found, or if more than
                one distinct value is found.
        """
        sizes = set()

        def _add(value: Any) -> None:
            # Non-integer values (including a missing key -> None) are ignored.
            if isinstance(value, int):
                sizes.add(value)

        # Top level and the "quantization" section.
        _add(cfg.get("group_size"))
        quantization = cfg.get("quantization")
        if isinstance(quantization, dict):
            _add(quantization.get("group_size"))

        # config_groups: the group itself and any dict directly inside it.
        for group in (cfg.get("config_groups") or {}).values():
            if not isinstance(group, dict):
                continue
            _add(group.get("group_size"))
            for sub in group.values():
                if isinstance(sub, dict):
                    _add(sub.get("group_size"))

        if not sizes:
            raise ValueError("No group_size found in config.")
        if len(sizes) > 1:
            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
        return next(iter(sizes))

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
        """Build the config from a flat or nested quantization dict.

        Flat format: ``quant_algo`` sits at the top level, next to
        ``kv_cache_scheme``, ``group_size`` / ``config_groups`` and ``ignore``.

        Nested format: everything lives under ``quantization``
        (``quant_algo``, ``kv_cache_quant_algo``, ``exclude_modules``) and the
        group size is resolved with :meth:`common_group_size`.

        Raises:
            ValueError: If the nested section cannot be parsed, the algorithm
                is not ``FP8``/``NVFP4``, or ``group_size``/``exclude_modules``
                is missing.
        """
        kv_cache_quant_algo = None
        group_size = None
        exclude_modules = []

        quant_method = config.get("quant_algo")
        if quant_method is not None:
            # ---- Flat format ----
            kv_cache_scheme = config.get("kv_cache_scheme")
            if isinstance(kv_cache_scheme, dict):
                if (
                    kv_cache_scheme.get("type") == "float"
                    and kv_cache_scheme.get("num_bits") == 8
                ):
                    kv_cache_quant_algo = "FP8"
                else:
                    kv_cache_quant_algo = "auto"
            elif isinstance(kv_cache_scheme, str):
                scheme_name = kv_cache_scheme.strip().upper()
                if scheme_name in ("FP8", "FLOAT8"):
                    kv_cache_quant_algo = "FP8"
                elif scheme_name in ("FP4", "FLOAT4", "NVFP4"):
                    kv_cache_quant_algo = "NVFP4"
                else:
                    kv_cache_quant_algo = "auto"
            else:
                kv_cache_quant_algo = "auto"

            group_size = config.get("group_size")
            if group_size is None:
                # Fall back to the weights section of the first config group.
                config_groups = config.get("config_groups", {})
                if config_groups:
                    first_group = next(iter(config_groups.values()), {})
                    weights_config = first_group.get("weights", {})
                    group_size = weights_config.get("group_size")

            exclude_modules = config.get("ignore", [])
        else:
            # ---- Nested format ----
            try:
                quant_config = cls.get_from_keys(config, ["quantization"])
                quant_method = quant_config["quant_algo"]
                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                if not kv_cache_quant_algo:
                    kv_cache_quant_algo = "auto"
                group_size = ModelOptFp4Config.common_group_size(config)
                exclude_modules = quant_config.get("exclude_modules", [])
            except (ValueError, KeyError) as err:
                # Chain the original error: this handler also catches
                # group_size problems, and without the cause the message
                # below would hide what actually went wrong.
                raise ValueError(
                    "Cannot find 'quant_algo' in the model's quantization config. "
                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
                ) from err

        if quant_method not in ("FP8", "NVFP4"):
            raise ValueError(
                "ModelOpt currently only supports: FP8, NVFP4"
                " quantizations in sglang. Please check the "
                "quantization config for your model's configuration."
            )
        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

        if group_size is None or exclude_modules is None:
            logger.warning(
                f"group_size: {group_size},"
                f"kv_cache_quant_algo: {kv_cache_quant_algo},"
                f"exclude_modules: {exclude_modules}"
            )
            raise ValueError(
                "NVFP4 quantization requires group_size and exclude_modules "
                "specified in the quantization config"
            )
        return cls(
            is_checkpoint_nvfp4_serialized,
            kv_cache_quant_algo,
            group_size,
            exclude_modules,
            config.get("packed_modules_mapping"),
        )

    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
        """Pick the quantization method for ``layer`` via the base-class helper."""
        return self._get_quant_method(
            layer,
            prefix,
            Linear=ModelOptFp4LinearMethod,
            Moe=ModelOptNvFp4FusedMoEMethod,
        )
|
|
|
|
class ModelOptFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    | Tensor Name    | datatype      | shape       |
    |----------------|---------------|-------------|
    | input_scale    | torch.float32 | scalar      |
    | weight         | NVFP4(SE2M1)  | [1, X, Y/2] |
    | weight_scale   | FP8-E4M3      | [X, Y]      |
    | weight_scale_2 | torch.float32 | scalar      |

    The weights are quantized per block of 16 elements.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp4Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the NVFP4 parameters (weight and three scales) on ``layer``.

        Args:
            layer: Module that receives the parameters and bookkeeping attributes.
            input_size_per_partition: Per-partition input width; must be a
                multiple of 16.
            output_partition_sizes: Widths of the logical outputs fused in
                this layer; their sum is the per-partition output width.
            input_size: Unused (deleted immediately).
            output_size: Unused (deleted immediately).
            params_dtype: Stored on the layer as ``params_dtype``.
            **extra_weight_attrs: Only ``weight_loader`` is read.

        Raises:
            ValueError: If the checkpoint is not NVFP4-serialized, or if
                ``input_size_per_partition`` is not a multiple of 16.
        """
        del input_size, output_size
        if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, "
                " dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.params_dtype = params_dtype
        layer.quant_config = self.quant_config
        if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model when in features size is not multiple of 16"
            )

        # The non-serialized case already raised above, so this always
        # resolves to float8_e4m3fn here.
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_nvfp4_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                # Input dimension is halved and stored as uint8 -- presumably
                # two 4-bit values per byte (see the class docstring table).
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # One float32 entry per logical output partition.
        input_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )

        layer.register_parameter("input_scale", input_scale)

        # One float32 entry per logical output partition.
        weight_scale_2 = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_2", weight_scale_2)

        # One scale per `group_size` input elements for every output row.
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.group_size,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded NVFP4 tensors into the layout of the active GEMM backend.

        Always sets ``alpha`` and ``input_scale_inv``. Then, depending on the
        backend: prepares the layer for Marlin; or pads and shuffles weight and
        scales for FlashInfer TRT-LLM; or pads the weight and rearranges the
        scales into ``weight_scale_interleaved`` for the remaining backends.
        ``weights_padding_cols`` is set on the layer in every path.

        Raises:
            ValueError: If Marlin is selected with ``group_size != 16``, or if
                a non-Marlin backend is used where ``is_blackwell_supported()``
                is false.
        """
        # Collapse the per-partition scales into one scalar each (maximum).
        input_scale_2 = layer.input_scale.max().to(torch.float32)
        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)

        # alpha is the product of both global scales; input_scale_inv is the
        # reciprocal used by `apply` when quantizing activations.
        copy_or_rebind_param(
            layer, "alpha", (input_scale_2 * weight_scale_2).to(torch.float32)
        )
        copy_or_rebind_param(
            layer, "input_scale_inv", (1 / input_scale_2).to(torch.float32)
        )

        # Record the row count before any of the padding done below.
        layer.output_size_per_partition = layer.weight.shape[0]

        # ---- Marlin backend ----
        if get_fp4_gemm_runner_backend().is_marlin():
            if self.quant_config.group_size != 16:
                raise ValueError(
                    f"NVFP4 Marlin requires group_size=16, got {self.quant_config.group_size}."
                )
            copy_or_rebind_param(layer, "input_global_scale", input_scale_2)
            copy_or_rebind_param(layer, "weight_global_scale", weight_scale_2)
            prepare_nvfp4_layer_for_marlin(layer)
            layer.weights_padding_cols = 0
            return

        if not is_blackwell_supported():
            raise ValueError(
                "ModelOpt NVFP4 native dense GEMM backends require SM100+. "
                "Use --fp4-gemm-backend marlin on SM80-SM90."
            )

        # ---- FlashInfer TRT-LLM backend ----
        if get_fp4_gemm_runner_backend().is_flashinfer_trtllm():
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            # Pad rows only (n_alignment=128, k_alignment=0).
            weight, _ = pad_nvfp4_weight(
                layer.weight.data, n_alignment=128, k_alignment=0
            )
            # Give the scales the same number of rows as the padded weight.
            scale = layer.weight_scale
            if scale.shape[0] != weight.shape[0]:
                pad_n = weight.shape[0] - scale.shape[0]
                scale = torch.nn.functional.pad(scale, (0, 0, 0, pad_n))

            # Pad the scale columns up to a multiple of 4 and widen the
            # weight to match.
            scale_k = scale.shape[1]
            weights_padding_cols = 0
            if scale_k % 4 != 0:
                padded_scale_k = round_up_to_multiple(scale_k, 4)
                pad_scale_k = padded_scale_k - scale_k
                scale = torch.nn.functional.pad(scale, (0, pad_scale_k, 0, 0))
                # 8 weight bytes per added scale column -- assumes 16 elements
                # per scale packed two per byte; TODO confirm against group_size.
                pad_weight_k = pad_scale_k * 8
                weight = torch.nn.functional.pad(weight, (0, pad_weight_k, 0, 0))
                # Remembered so `apply` can pad activations by the same amount.
                weights_padding_cols = pad_weight_k

            # Shuffle weight and scales with a tile size of 128; the scale is
            # reshaped back to its pre-shuffle shape and viewed as FP8 again.
            epilogue_tile_m = 128
            shuffled_scale_shape = scale.shape
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            scale = (
                shuffle_matrix_sf_a(scale.view(torch.uint8), epilogue_tile_m)
                .reshape(shuffled_scale_shape)
                .view(torch.float8_e4m3fn)
            )

            alias_or_bind_derived_param(
                layer, "weight_scale", "weight_scale_interleaved", scale
            )
            copy_or_rebind_param(layer, "weight", weight)
            layer.weights_padding_cols = weights_padding_cols
            return

        # ---- Remaining backends ----
        weight, weights_padding_cols = pad_nvfp4_weight(layer.weight.data)
        layer.weights_padding_cols = weights_padding_cols
        copy_or_rebind_param(layer, "weight", weight)

        # Pad the scales to multiples of (128 rows, 4 columns); a 2-D scale
        # tensor is handled as a batch of one.
        scales = layer.weight_scale
        scale_ndim = scales.ndim
        if scale_ndim == 2:
            scales = scales.unsqueeze(0)
        assert scales.ndim == 3
        B, M, K = scales.shape
        M_padded = round_up_to_multiple(M, 128)
        K_padded = round_up_to_multiple(K, 4)
        # No device is given, so this buffer is created on the default device
        # and only moved with .cuda() further down.
        padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
        padded_scales[:B, :M, :K] = scales

        # Copy of the unpadded scales in their original rank, taken before the
        # alias call below.
        # NOTE(review): presumably needed because that call may overwrite the
        # storage behind `weight_scale` -- confirm.
        raw_scale_snapshot = (
            (scales.squeeze(0) if scale_ndim == 2 else scales).detach().clone()
        )

        # Split rows into (rows/128, 4, 32) and columns into (cols/4, 4), then
        # reorder the axes to (batch, rows/128, cols/4, 32, 4, 4).
        batches, rows, cols = padded_scales.shape
        assert rows % 128 == 0
        assert cols % 4 == 0
        padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
        padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
        padded_scales = padded_scales.contiguous().cuda()
        # Restore the rank the scale tensor had on entry.
        padded_scales = (
            padded_scales.reshape(M_padded, K_padded)
            if scale_ndim == 2
            else padded_scales.reshape(B, M_padded, K_padded)
        )
        alias_or_bind_derived_param(
            layer, "weight_scale", "weight_scale_interleaved", padded_scales
        )

        # Optional extra layout for layers flagged for SwiGLU fusion.
        if getattr(layer, "_interleave_for_swiglu_fusion", False):
            from sglang.srt.layers.quantization.nvfp4_gemm_swiglu_nvfp4_quant import (
                interleave_linear_and_gate,
                swizzle_blockscale_2d,
            )

            w = layer.weight.data
            assert weights_padding_cols == 0, (
                "_interleave_for_swiglu_fusion does not support K-padded weights; "
                f"got weights_padding_cols={weights_padding_cols}."
            )
            assert raw_scale_snapshot.shape[0] == w.shape[0], (
                "_interleave_for_swiglu_fusion requires no N-padding; "
                f"raw_scale rows={raw_scale_snapshot.shape[0]} vs weight rows={w.shape[0]}."
            )
            assert w.shape[0] % 128 == 0, (
                "_interleave_for_swiglu_fusion requires N % 128 == 0 (group_size=64 "
                f"with gate+up halves); got N={w.shape[0]}."
            )

            # The two row halves are swapped (second half first) before
            # interleaving in groups of 64.
            gate_w, up_w = w.chunk(2, dim=0)
            w_swiglu = interleave_linear_and_gate(
                torch.cat((up_w, gate_w), dim=0), group_size=64, dim=0
            )

            # Same swap and interleave for the unpadded scales, followed by a
            # 2-D swizzle.
            gate_s, up_s = raw_scale_snapshot.chunk(2, dim=0)
            w_scale_swiglu = swizzle_blockscale_2d(
                interleave_linear_and_gate(
                    torch.cat((up_s, gate_s), dim=0), group_size=64, dim=0
                )
            )

            layer.weight_swiglu_interleaved = w_swiglu
            layer.weight_scale_swiglu_interleaved = w_scale_swiglu

            # Replace the non-fused tensors with empty ones of the same dtype
            # and device; after this only the *_swiglu_interleaved tensors
            # hold data.
            layer.weight.data = torch.empty(
                0, dtype=layer.weight.dtype, device=layer.weight.device
            )
            layer.weight_scale_interleaved.data = torch.empty(
                0,
                dtype=layer.weight_scale_interleaved.dtype,
                device=layer.weight_scale_interleaved.device,
            )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the NVFP4 linear layer output for ``x``.

        Args:
            layer: Layer prepared by ``process_weights_after_loading``.
            x: Input tensor, or an ``(x_fp4, x_scale_interleaved)`` tuple when
                the layer sets ``_accepts_prequantized_fp4``.
            bias: Optional bias added to the result.

        Returns:
            Tensor of shape ``[rows of x, layer.output_size_per_partition]``.
        """
        # Marlin works on the unquantized input and returns directly.
        if get_fp4_gemm_runner_backend().is_marlin():
            return apply_fp4_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                weight_global_scale=layer.weight_global_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        # Either accept an already-quantized tuple or quantize here. With a
        # tuple the output dtype comes from the layer, because the packed
        # input no longer carries it.
        if getattr(layer, "_accepts_prequantized_fp4", False) and isinstance(x, tuple):
            x_fp4, x_scale_interleaved = x
            x_m = x_fp4.shape[0]
            output_dtype = layer.params_dtype
        else:
            x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)
            x_m, _ = x.shape
            output_dtype = x.dtype

        output_size = layer.output_size_per_partition
        w_n, _ = layer.weight.shape
        output_shape = [x_m, output_size]

        assert x_fp4.dtype == torch.uint8
        assert layer.weight.dtype == torch.uint8
        assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
        assert layer.alpha.dtype == torch.float32

        # Pad activations by the column padding recorded when the weights
        # were processed.
        weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
        x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)

        w = layer.weight
        w_scale_interleaved = layer.weight_scale_interleaved
        # Weight and scales are passed transposed when FlashInfer FP4 GEMM is
        # enabled and the backend is not CUTLASS.
        if (
            enable_flashinfer_fp4_gemm
            and not get_fp4_gemm_runner_backend().is_cutlass()
        ):
            w = layer.weight.T
            w_scale_interleaved = layer.weight_scale_interleaved.T

        out = fp4_gemm(
            x_fp4,
            w,
            x_scale_interleaved,
            w_scale_interleaved,
            layer.alpha,
            output_dtype,
            w_n,
        )

        # Trim the result back to the unpadded output width.
        out = slice_nvfp4_output(out, output_size)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
|
|
|
|
| class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): |
| """ |
| MoE Method for FP4 Quantization with Blockscales and PerTensorScales |
| Args: |
| quant_config: NVFP4 Quant Config |
| """ |
|
|
| def __init__(self, quant_config: ModelOptFp4Config): |
| self.quant_config = quant_config |
| moe_runner_backend = get_moe_runner_backend() |
| if moe_runner_backend.is_auto() and is_cuda(): |
| capability = get_device_capability() |
| use_marlin_fallback = (8, 0) <= capability < (10, 0) |
| else: |
| use_marlin_fallback = moe_runner_backend.is_marlin() |
| if not is_blackwell_supported() and not use_marlin_fallback: |
| raise ValueError( |
| "Current platform does not support NVFP4" |
| " quantization with the selected MoE backend. Please use " |
| "Blackwell and above, or use moe_runner_backend=marlin on SM80+." |
| ) |
| self.enable_flashinfer_trtllm_moe = ( |
| get_moe_runner_backend().is_flashinfer_trtllm() |
| or get_moe_runner_backend().is_flashinfer_trtllm_routed() |
| ) |
| self._cache_permute_indices = {} |
|
|
| @property |
| def enable_flashinfer_cutlass_moe(self) -> bool: |
| from sglang.srt.layers.moe import get_moe_runner_backend |
|
|
| """Access the global enable_flashinfer_cutlass_moe setting.""" |
| return get_moe_runner_backend().is_flashinfer_cutlass() |
|
|
| @property |
| def enable_flashinfer_cutedsl_moe(self) -> bool: |
| """Access the global enable_flashinfer_cutedsl_moe setting.""" |
| from sglang.srt.layers.moe import get_moe_runner_backend |
|
|
| return get_moe_runner_backend().is_flashinfer_cutedsl() |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| @property |
| def _is_cutedsl_v1_deepep(self) -> bool: |
| """CuteDSL v1 + DeepEP low-latency path (masked grouped GEMM).""" |
| return is_flashinfer_cutedsl_v1_path() |
|
|
| @property |
| def _is_cutedsl_v2_standard(self) -> bool: |
| """CuteDSL v2 standard path (a2a=none or flashinfer, uses CuteDslMoEWrapper).""" |
| return self.enable_flashinfer_cutedsl_moe and not self._is_cutedsl_v1_deepep |
|
|
| def create_weights( |
| self, |
| layer: torch.nn.Module, |
| num_experts: int, |
| hidden_size: int, |
| intermediate_size_per_partition: int, |
| params_dtype: torch.dtype, |
| **extra_weight_attrs, |
| ): |
| is_nvfp4_online = getattr(self.quant_config, "is_nvfp4_online", False) |
| if not self.quant_config.is_checkpoint_nvfp4_serialized and not is_nvfp4_online: |
| raise ValueError( |
| "NVFP4 quantization was selected, " |
| " dynamic quantization is not supported." |
| ) |
| |
| |
| |
| |
| if is_nvfp4_online: |
| if not self.enable_flashinfer_trtllm_moe: |
| raise ValueError( |
| "--quantization nvfp4_online supports only " |
| "--moe-runner-backend flashinfer_trtllm or " |
| "flashinfer_trtllm_routed." |
| ) |
|
|
| |
| layer.intermediate_size_per_partition = intermediate_size_per_partition |
| layer.params_dtype = params_dtype |
| layer.quant_config = self.quant_config |
|
|
| weight_dtype = torch.uint8 |
| weight_scale_dtype = torch.float8_e4m3fn |
| weight_loader = extra_weight_attrs.get("weight_loader") |
| if is_nvfp4_online: |
| weight_loader = self.get_online_weight_loader(layer, weight_loader) |
| |
| num_shards = 2 if layer.moe_runner_config.is_gated else 1 |
|
|
| w13_weight = ModelWeightParameter( |
| data=torch.empty( |
| layer.num_local_experts, |
| num_shards * intermediate_size_per_partition, |
| |
| hidden_size // 2, |
| dtype=weight_dtype, |
| ), |
| input_dim=1, |
| output_dim=2, |
| weight_loader=weight_loader, |
| ) |
| layer.register_parameter("w13_weight", w13_weight) |
|
|
| |
| w2_weight = ModelWeightParameter( |
| data=torch.empty( |
| layer.num_local_experts, |
| hidden_size, |
| |
| intermediate_size_per_partition // 2, |
| dtype=weight_dtype, |
| ), |
| input_dim=1, |
| output_dim=2, |
| weight_loader=weight_loader, |
| ) |
| layer.register_parameter("w2_weight", w2_weight) |
|
|
| w13_weight_scale = ModelWeightParameter( |
| data=torch.empty( |
| layer.num_local_experts, |
| num_shards * intermediate_size_per_partition, |
| hidden_size // self.quant_config.group_size, |
| dtype=weight_scale_dtype, |
| ), |
| input_dim=1, |
| output_dim=2, |
| weight_loader=weight_loader, |
| ) |
| layer.register_parameter("w13_weight_scale", w13_weight_scale) |
|
|
| |
| |
| |
| if self.enable_flashinfer_trtllm_moe: |
| layer.w13_blockscale_swizzled = None |
| else: |
| layer.w13_blockscale_swizzled = Parameter( |
| swizzle_blockscale(layer.w13_weight_scale), requires_grad=False |
| ) |
|
|
| w2_weight_scale = ModelWeightParameter( |
| data=torch.empty( |
| layer.num_local_experts, |
| hidden_size, |
| intermediate_size_per_partition // self.quant_config.group_size, |
| dtype=weight_scale_dtype, |
| ), |
| input_dim=1, |
| output_dim=2, |
| weight_loader=weight_loader, |
| ) |
| layer.register_parameter("w2_weight_scale", w2_weight_scale) |
|
|
| if self.enable_flashinfer_trtllm_moe: |
| layer.w2_blockscale_swizzled = None |
| else: |
| layer.w2_blockscale_swizzled = Parameter( |
| swizzle_blockscale(layer.w2_weight_scale), requires_grad=False |
| ) |
|
|
| from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported |
|
|
| extra_weight_attrs.update( |
| {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} |
| ) |
|
|
| w13_weight_scale_shape = ( |
| (layer.num_local_experts, 2) |
| if layer.moe_runner_config.is_gated |
| else (layer.num_local_experts,) |
| ) |
| w13_weight_scale_2 = PerTensorScaleParameter( |
| data=torch.empty(w13_weight_scale_shape, dtype=torch.float32), |
| weight_loader=weight_loader, |
| ) |
| layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) |
|
|
| w2_weight_scale_2 = PerTensorScaleParameter( |
| data=torch.empty(layer.num_local_experts, dtype=torch.float32), |
| weight_loader=weight_loader, |
| ) |
| layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) |
|
|
| if is_nvfp4_online and self.quant_config.is_checkpoint_fp8_serialized: |
| |
| |
| |
| w13_source_weight_scale_inv = PerTensorScaleParameter( |
| data=torch.empty(0, dtype=torch.float32), |
| weight_loader=weight_loader, |
| ) |
| layer.register_parameter( |
| "w13_weight_scale_inv", w13_source_weight_scale_inv |
| ) |
| w2_source_weight_scale_inv = PerTensorScaleParameter( |
| data=torch.empty(0, dtype=torch.float32), |
| weight_loader=weight_loader, |
| ) |
| layer.register_parameter("w2_weight_scale_inv", w2_source_weight_scale_inv) |
|
|
| extra_weight_attrs.update( |
| {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} |
| ) |
|
|
| w13_input_scale_shape = (layer.num_experts, num_shards) |
| w13_input_scale = PerTensorScaleParameter( |
| data=torch.empty(w13_input_scale_shape, dtype=torch.float32), |
| weight_loader=weight_loader, |
| ) |
| w13_input_scale._sglang_require_global_experts = True |
| layer.register_parameter("w13_input_scale", w13_input_scale) |
|
|
| w2_input_scale = PerTensorScaleParameter( |
| data=torch.empty(layer.num_experts, dtype=torch.float32), |
| weight_loader=weight_loader, |
| ) |
| w2_input_scale._sglang_require_global_experts = True |
| layer.register_parameter("w2_input_scale", w2_input_scale) |
|
|
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: |
| """Process FP4 MoE weights after loading from serialized checkpoint. |
| |
| Only supports pre-quantized checkpoints with FP8 weights and scales. |
| """ |
| |
| if layer.moe_runner_config.is_gated: |
| if layer.w13_weight_scale_2.dim() == 1: |
| |
| w13_weight_scale_2 = layer.w13_weight_scale_2 |
| else: |
| if layer.w13_weight_scale_2.shape[1] >= 2 and not torch.allclose( |
| layer.w13_weight_scale_2[:, 0], |
| layer.w13_weight_scale_2[:, 1], |
| ): |
| logger.warning_once( |
| "w1_weight_scale_2 must match w3_weight_scale_2. " |
| "Accuracy may be affected." |
| ) |
|
|
| w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] |
| else: |
| w13_weight_scale_2 = layer.w13_weight_scale_2[:] |
|
|
| moe_runner_backend = getattr( |
| self, "_moe_runner_backend", get_moe_runner_backend() |
| ) |
| if moe_runner_backend.is_marlin(): |
| copy_or_rebind_param( |
| layer, |
| "w13_weight_scale_2", |
| w13_weight_scale_2.contiguous(), |
| ) |
| prepare_moe_nvfp4_layer_for_marlin(layer) |
| return |
|
|
| |
| if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe: |
| w13_input_scale = layer.w13_input_scale.max().to(torch.float32) |
| w2_input_scale = layer.w2_input_scale.max().to(torch.float32) |
| elif self.enable_flashinfer_cutedsl_moe: |
| |
| w13_input_scale = ( |
| layer.w13_input_scale.max() |
| .to(torch.float32) |
| .repeat(layer.w13_input_scale.shape[0]) |
| ) |
| w2_input_scale = layer.w2_input_scale |
|
|
| def _slice_scale(w): |
| assert w.shape == (layer.num_experts,) |
| assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts |
| return w[ |
| layer.moe_ep_rank |
| * layer.num_local_experts : (layer.moe_ep_rank + 1) |
| * layer.num_local_experts |
| ] |
|
|
| w13_input_scale = _slice_scale(w13_input_scale) |
| w2_input_scale = _slice_scale(w2_input_scale) |
|
|
| if MOE_NVFP4_DISPATCH: |
| assert torch.all(w13_input_scale == w13_input_scale[0]) |
| w13_input_scale = w13_input_scale[0] |
| else: |
| w13_input_scale = layer.w13_input_scale.max(dim=-1).values.to(torch.float32) |
| w2_input_scale = layer.w2_input_scale |
|
|
| if self.quant_config.use_per_token_activation: |
| |
| |
| w13_input_scale = torch.ones_like(w13_input_scale, dtype=torch.float32) |
| w2_input_scale = torch.ones_like(w2_input_scale, dtype=torch.float32) |
|
|
| |
| copy_or_rebind_param( |
| layer, |
| "g1_alphas", |
| (w13_input_scale * w13_weight_scale_2).to(torch.float32), |
| ) |
| copy_or_rebind_param( |
| layer, |
| "g2_alphas", |
| (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), |
| ) |
| copy_or_rebind_param( |
| layer, |
| "w13_input_scale_quant", |
| (1 / w13_input_scale).to(torch.float32), |
| ) |
| copy_or_rebind_param( |
| layer, |
| "w2_input_scale_quant", |
| (1 / w2_input_scale).to(torch.float32), |
| ) |
|
|
| |
| layer.dispatcher.set_quant_config( |
| { |
| "input_global_scale": ( |
| layer.w13_input_scale_quant |
| if MOE_NVFP4_DISPATCH |
| or should_use_flashinfer_cutlass_moe_fp4_allgather() |
| else None |
| ) |
| } |
| ) |
| block_size = 16 |
| |
| assert_dim = 2 if layer.moe_runner_config.is_gated else 1 |
| for name, weight_scale in [ |
| ("w13", layer.w13_weight_scale), |
| ("w2", layer.w2_weight_scale), |
| ]: |
| |
| if get_moe_runner_backend().is_flashinfer_trtllm(): |
| expected_blocks = { |
| "w13": layer.w13_weight.shape[2] * 2 // block_size, |
| "w2": layer.w2_weight.shape[2] * 2 // block_size, |
| } |
| assert ( |
| weight_scale.shape[-1] == expected_blocks[name] |
| ), f"Expected {name}_weight_scale.dim(2) == {expected_blocks[name]}, got {weight_scale.shape[-1]}" |
| else: |
| if weight_scale.shape[assert_dim] % 4 != 0: |
| logger.warning( |
| "NVFP4 %s_weight_scale K' not multiple of 4: shape=%s, group_size=%s", |
| name, |
| tuple(weight_scale.shape), |
| getattr(self.quant_config, "group_size", None), |
| ) |
| assert ( |
| weight_scale.dtype == torch.float8_e4m3fn |
| ), f"{name} Weight Blockscale must be represented as FP8-E4M3" |
|
|
| |
| if ( |
| self.enable_flashinfer_trtllm_moe |
| and reorder_rows_for_gated_act_gemm is not None |
| and shuffle_matrix_sf_a is not None |
| ): |
| from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( |
| align_fp4_moe_weights_for_flashinfer_trtllm, |
| ) |
|
|
| |
| align_fp4_moe_weights_for_flashinfer_trtllm(layer) |
| |
| |
| layer.w13_blockscale_swizzled = layer.w13_weight_scale |
| layer.w2_blockscale_swizzled = layer.w2_weight_scale |
|
|
| else: |
| |
|
|
| if self._is_cutedsl_v2_standard and layer.moe_runner_config.is_gated: |
| |
| |
| |
| |
| from sglang.srt.layers.moe.moe_runner.flashinfer_cutedsl import ( |
| interleave_w13_halves, |
| ) |
|
|
| layer.w13_weight = Parameter( |
| interleave_w13_halves( |
| layer.w13_weight.view(torch.uint8), group_size=64, dim=1 |
| ).contiguous(), |
| requires_grad=False, |
| ) |
| layer.w13_weight_scale = Parameter( |
| interleave_w13_halves( |
| layer.w13_weight_scale, group_size=64, dim=1 |
| ).contiguous(), |
| requires_grad=False, |
| ) |
|
|
| |
| w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) |
| alias_or_bind_derived_param( |
| layer, |
| "w13_weight_scale", |
| "w13_blockscale_swizzled", |
| w13_blockscale_swizzled, |
| ) |
|
|
| w13_weight = layer.w13_weight |
| intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1) |
| if intermediate_size_pad: |
| |
| |
| assert not layer.moe_runner_config.is_gated, ( |
| "The intermediate size required padding, " |
| "but padding is also implemented for gated activations" |
| ) |
|
|
| copy_or_rebind_param( |
| layer, |
| "w13_weight", |
| torch.nn.functional.pad( |
| w13_weight, (0, 0, 0, intermediate_size_pad) |
| ), |
| ) |
| copy_or_rebind_param( |
| layer, |
| "w2_weight", |
| torch.nn.functional.pad( |
| layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0) |
| ), |
| ) |
| copy_or_rebind_param( |
| layer, |
| "w2_weight_scale", |
| torch.nn.functional.pad( |
| layer.w2_weight_scale, (0, intermediate_size_pad // 16) |
| ), |
| ) |
|
|
| |
| w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) |
| alias_or_bind_derived_param( |
| layer, |
| "w2_weight_scale", |
| "w2_blockscale_swizzled", |
| w2_blockscale_swizzled, |
| ) |
|
|
| if self._is_cutedsl_v2_standard: |
| |
| |
| |
| from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout |
|
|
| from sglang.srt.layers.moe.moe_runner.flashinfer_cutedsl import ( |
| _FP4_SF_VEC_SIZE, |
| ) |
|
|
| sf_vec_size = _FP4_SF_VEC_SIZE |
| num_local_experts = layer.w13_weight.shape[0] |
| w13_m = layer.w13_weight.shape[1] |
| w13_k = layer.w13_weight.shape[2] * 2 |
| w2_m = layer.w2_weight.shape[1] |
| w2_k = layer.w2_weight.shape[2] * 2 |
| layer.w13_blockscale_mma = Parameter( |
| convert_sf_to_mma_layout( |
| layer.w13_blockscale_swizzled.contiguous() |
| .view(torch.uint8) |
| .reshape(-1), |
| m=w13_m, |
| k=w13_k, |
| num_groups=num_local_experts, |
| sf_vec_size=sf_vec_size, |
| ), |
| requires_grad=False, |
| ) |
| layer.w2_blockscale_mma = Parameter( |
| convert_sf_to_mma_layout( |
| layer.w2_blockscale_swizzled.contiguous() |
| .view(torch.uint8) |
| .reshape(-1), |
| m=w2_m, |
| k=w2_k, |
| num_groups=num_local_experts, |
| sf_vec_size=sf_vec_size, |
| ), |
| requires_grad=False, |
| ) |
|
|
| |
|
|
| |
| device = layer.w13_weight.device |
| inter_size = layer.w2_weight.shape[2] * 2 |
| hidden_size = layer.w13_weight.shape[2] * 2 |
| existing_params = getattr(layer, "cutlass_moe_params", None) |
| if ( |
| existing_params is None |
| or existing_params.cutlass_moe_type != CutlassMoEType.BlockscaledFP4 |
| or existing_params.num_experts != layer.num_experts |
| or existing_params.intermediate_size_per_partition != inter_size |
| or existing_params.hidden_size != hidden_size |
| or existing_params.device != device |
| ): |
| layer.cutlass_moe_params = CutlassMoEParams( |
| CutlassMoEType.BlockscaledFP4, |
| device, |
| num_experts=layer.num_experts, |
| intermediate_size_per_partition=inter_size, |
| hidden_size=hidden_size, |
| ) |
|
|
@property
def load_up_proj_weight_first(self) -> bool:
    """Tell whether the gated w13 weight should be loaded up-projection first.

    The result is truthy only when the runner config is gated AND either the
    FlashInfer CUTLASS MoE path is enabled or the CuteDSL v2 standard path is
    selected. Reads ``self.moe_runner_config``, which ``create_moe_runner``
    assigns, so that method must have run before this property is queried.

    NOTE(review): judging by the name, the weight loader uses this to place
    the up-projection shard ahead of the gate shard -- confirm against the
    loader.
    """
    gated = self.moe_runner_config.is_gated
    if not gated:
        # Non-gated configs never qualify; the backend flags are not even
        # consulted in this case.
        return gated
    return self.enable_flashinfer_cutlass_moe or self._is_cutedsl_v2_standard
|
|
def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Resolve the MoE runner backend and build the runner for this method.

    Stores ``moe_runner_config`` on ``self``, resolves an "auto" backend to a
    concrete one, caches the result as ``self._moe_runner_backend`` and, for
    every backend except FlashInfer CUTLASS, creates ``self.runner``.

    Args:
        layer: The MoE layer this method is attached to (not read here).
        moe_runner_config: Runner configuration; kept on ``self`` and passed
            to ``MoeRunner``.
    """
    self.moe_runner_config = moe_runner_config

    backend = get_moe_runner_backend()
    if backend.is_auto():
        # Auto-selection: CUDA devices whose capability tuple lies in
        # [(8, 0), (10, 0)) get Marlin, everything else gets FlashInfer
        # TRT-LLM. The capability is only queried when is_cuda() is true.
        use_marlin = is_cuda() and (8, 0) <= get_device_capability() < (10, 0)
        backend = (
            MoeRunnerBackend.MARLIN
            if use_marlin
            else MoeRunnerBackend.FLASHINFER_TRTLLM
        )
    self._moe_runner_backend = backend

    if backend.is_flashinfer_cutedsl():
        # The bound name is never used here, so this import exists purely for
        # its import-time effects (presumably backend registration -- confirm).
        import sglang.srt.layers.moe.moe_runner.flashinfer_cutedsl

    if backend.is_flashinfer_cutlass():
        # No MoeRunner is created for the FlashInfer CUTLASS backend.
        return
    self.runner = MoeRunner(backend, moe_runner_config)
|
|
def apply(
    self,
    layer: FusedMoE,
    dispatch_output: StandardDispatchOutput,
) -> CombineInput:
    """Run the FP4 MoE computation for already-dispatched tokens.

    The kernel path is chosen from the resolved runner backend and this
    method's ``enable_*`` flags. Paths are tried in the order below and the
    first match returns:

    1. Marlin, when the backend reports ``is_marlin()``.
    2. FlashInfer TRT-LLM, when ``enable_flashinfer_trtllm_moe`` is set and
       the layer carries a ``g1_scale_c`` attribute.
    3. FlashInfer CuteDSL, when ``enable_flashinfer_cutedsl_moe`` is set
       (split further on ``_is_cutedsl_v1_deepep``).
    4. FlashInfer CUTLASS, when ``enable_flashinfer_cutlass_moe`` is set;
       this path calls the kernel directly instead of using ``self.runner``.
    5. Otherwise ``cutlass_moe_fp4``.

    Args:
        layer: MoE layer holding the weight / scale attributes read below.
        dispatch_output: Dispatcher output providing hidden states and the
            top-k routing result.

    Returns:
        For paths 1-3, whatever ``self.runner.run`` returns; for paths 4-5, a
        ``StandardCombineInput`` wrapping the kernel output.

    Raises:
        AssertionError: If the configured activation is not in
            ``_SUPPORTED_ACT_STRS``, or (CUTLASS path) if
            ``apply_router_weight_on_input`` is set or the mapped activation
            type is not in the supported set.
    """
    from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

    moe_runner_config = self.moe_runner_config
    activation = moe_runner_config.activation

    # The process-wide backend is looked up unconditionally (exactly like a
    # getattr default) and is only used when no backend was cached on self.
    global_backend = get_moe_runner_backend()
    moe_runner_backend = getattr(self, "_moe_runner_backend", global_backend)

    assert (
        activation in _SUPPORTED_ACT_STRS
    ), f"{activation=} not in supported {_SUPPORTED_ACT_STRS}"

    # ------------------------------------------------------------------
    # Path 1: Marlin
    # ------------------------------------------------------------------
    if moe_runner_backend.is_marlin():
        from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo

        # An expert map is forwarded only when the layer's dispatcher exposes
        # a non-None ``local_expert_mapping``; without one the global expert
        # count stays at -1.
        expert_map = getattr(
            getattr(layer, "dispatcher", None), "local_expert_mapping", None
        )
        global_num_experts = (
            -1 if expert_map is None else moe_runner_config.num_experts
        )

        marlin_info = MarlinMoeQuantInfo(
            w13_qweight=layer.w13_weight,
            w2_qweight=layer.w2_weight,
            w13_scales=layer.w13_weight_scale,
            w2_scales=layer.w2_weight_scale,
            w13_g_idx_sort_indices=None,
            w2_g_idx_sort_indices=None,
            weight_bits=4,
            w13_global_scale=layer.w13_weight_scale_2,
            w2_global_scale=layer.w2_weight_scale_2,
            expert_map=expert_map,
            global_num_experts=global_num_experts,
        )
        return self.runner.run(dispatch_output, marlin_info)

    # ------------------------------------------------------------------
    # Path 2: FlashInfer TRT-LLM (only when the layer has ``g1_scale_c``)
    # ------------------------------------------------------------------
    if self.enable_flashinfer_trtllm_moe and hasattr(layer, "g1_scale_c"):
        from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
            FlashInferTrtllmFp4MoeQuantInfo,
        )
        from sglang.srt.layers.moe.utils import RoutingMethodType

        trtllm_info = FlashInferTrtllmFp4MoeQuantInfo(
            w13_weight=layer.w13_weight.data,
            w2_weight=layer.w2_weight.data,
            w13_weight_scale=layer.w13_weight_scale.data,
            w2_weight_scale=layer.w2_weight_scale.data,
            g1_scale_c=layer.g1_scale_c.data,
            g1_alphas=layer.g1_alphas.data,
            g2_alphas=layer.g2_alphas.data,
            w13_input_scale_quant=layer.w13_input_scale_quant,
            global_num_experts=layer.num_experts,
            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
            local_num_experts=layer.num_local_experts,
            intermediate_size_per_partition=layer.intermediate_size_per_partition,
            # Layers without an explicit routing method use the default one.
            routing_method_type=getattr(
                layer, "routing_method_type", RoutingMethodType.Default
            ),
            use_per_token_activation=self.quant_config.use_per_token_activation,
        )
        return self.runner.run(dispatch_output, trtllm_info)

    # ------------------------------------------------------------------
    # Path 3: FlashInfer CuteDSL
    # ------------------------------------------------------------------
    if self.enable_flashinfer_cutedsl_moe:
        from sglang.srt.layers.moe.moe_runner.flashinfer_cutedsl import (
            CuteDslFp4MoeQuantInfo,
            ensure_cutedsl_wrapper,
        )

        if self._is_cutedsl_v1_deepep:
            # v1 / DeepEP variant: swizzled block scales and the layer's own
            # alphas and input scales are passed straight through.
            overlap_args = getattr(self.runner, "down_gemm_overlap_args", None)
            cutedsl_v1_info = CuteDslFp4MoeQuantInfo(
                w13_weight=layer.w13_weight,
                w2_weight=layer.w2_weight,
                w13_weight_sf=layer.w13_blockscale_swizzled,
                w2_weight_sf=layer.w2_blockscale_swizzled,
                w1_alpha=layer.g1_alphas,
                w2_alpha=layer.g2_alphas,
                a1_scale=layer.w13_input_scale_quant,
                a2_scale=layer.w2_input_scale_quant,
                use_nvfp4_dispatch=MOE_NVFP4_DISPATCH,
                down_gemm_overlap_args=overlap_args,
            )
            return self.runner.run(dispatch_output, cutedsl_v1_info)

        # Remaining CuteDSL variant: uses the wrapper and scales stored on the
        # layer (``_cutedsl_*`` attributes are read after this call).
        ensure_cutedsl_wrapper(layer)
        w1_alpha, fc2_input_scale, w2_alpha = layer._cutedsl_scales

        # The swizzled scales are read unconditionally and serve as the
        # fallback when the MMA-layout scales are absent from the layer.
        w13_swizzled_sf = layer.w13_blockscale_swizzled
        w13_sf = getattr(layer, "w13_blockscale_mma", w13_swizzled_sf)
        w2_swizzled_sf = layer.w2_blockscale_swizzled
        w2_sf = getattr(layer, "w2_blockscale_mma", w2_swizzled_sf)

        cutedsl_info = CuteDslFp4MoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            w13_weight_sf=w13_sf,
            w2_weight_sf=w2_sf,
            w1_alpha=w1_alpha,
            w2_alpha=w2_alpha,
            a1_scale=layer._cutedsl_input_scale,
            a2_scale=fc2_input_scale,
            wrapper=layer._cutedsl_wrapper,
        )
        return self.runner.run(dispatch_output, cutedsl_info)

    # ------------------------------------------------------------------
    # Path 4: FlashInfer CUTLASS (kernel invoked directly)
    # ------------------------------------------------------------------
    if self.enable_flashinfer_cutlass_moe:
        from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
            get_activation_type,
        )
        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

        assert (
            not moe_runner_config.apply_router_weight_on_input
        ), "apply_router_weight_on_input is not supported for Flashinfer"

        # Map the activation string (plus gating) to the kernel's enum and
        # reject anything outside the set accepted on this path.
        fi_activation = ActivationType(
            get_activation_type(activation, is_gated=moe_runner_config.is_gated)
        )
        supported_activations = {
            ActivationType.Swiglu,
            ActivationType.Geglu,
            ActivationType.Relu2,
            ActivationType.Identity,
        }
        assert fi_activation in supported_activations, (
            f"Activation {activation!r} (is_gated={moe_runner_config.is_gated}) "
            f"maps to {fi_activation.name}, which is not supported by the "
            "flashinfer cutlass fp4 moe kernel."
        )

        x = dispatch_output.hidden_states
        x_sf = dispatch_output.hidden_states_scale
        topk_output = dispatch_output.topk_output
        topk_weights = topk_output.topk_weights
        topk_ids = topk_output.topk_ids

        output_dtype = torch.bfloat16

        if DispatchOutputChecker.format_is_flashinfer(dispatch_output):
            # The dispatcher already supplies the output buffer.
            symm_output = dispatch_output.moe_output
        else:
            # Column count is doubled when scale factors accompany the input
            # and the layer is gated.
            # NOTE(review): presumably this undoes FP4 packing (two values per
            # byte); why it is tied to ``is_gated`` is not visible here --
            # confirm.
            needs_doubling = x_sf is not None and layer.moe_runner_config.is_gated
            output_col = x.shape[1] * 2 if needs_doubling else x.shape[1]
            with use_symmetric_memory(
                get_tp_group(), disabled=not is_allocation_symmetric()
            ):
                symm_output = torch.empty(
                    (x.shape[0], output_col),
                    dtype=output_dtype,
                    device=x.device,
                )

        # Optional per-expert swiglu parameters: emitted only when the config
        # sets an alpha or a clamp limit. Alpha defaults to 1.0, beta is
        # always 1.0, and the limit is included only when configured.
        gemm1_alpha = moe_runner_config.gemm1_alpha
        gemm1_limit = moe_runner_config.gemm1_clamp_limit
        swiglu_kwargs = {}
        if gemm1_alpha is not None or gemm1_limit is not None:
            fill_values = {
                "swiglu_alpha": gemm1_alpha if gemm1_alpha is not None else 1.0,
                "swiglu_beta": 1.0,
            }
            if gemm1_limit is not None:
                fill_values["swiglu_limit"] = gemm1_limit
            experts_on_rank = layer.w13_weight.shape[0]
            swiglu_kwargs = {
                key: torch.full(
                    (experts_on_rank,),
                    value,
                    dtype=torch.float32,
                    device=x.device,
                )
                for key, value in fill_values.items()
            }

        # Scale list in the order the call below passes it: input scale,
        # block scales (viewed as int32) and alphas for w13, then for w2.
        quant_scales = [
            layer.w13_input_scale_quant,
            layer.w13_blockscale_swizzled.view(torch.int32),
            layer.g1_alphas,
            layer.w2_input_scale_quant,
            layer.w2_blockscale_swizzled.view(torch.int32),
            layer.g2_alphas,
        ]

        kernel_outputs = flashinfer_cutlass_fused_moe(
            output=symm_output,
            input=x,
            token_selected_experts=topk_ids.to(torch.int),
            token_final_scales=topk_weights,
            # Weights are handed over reinterpreted as int64.
            fc1_expert_weights=layer.w13_weight.view(torch.long),
            fc2_expert_weights=layer.w2_weight.view(torch.long),
            output_dtype=output_dtype,
            input_sf=x_sf,
            quant_scales=quant_scales,
            ep_size=layer.moe_ep_size,
            ep_rank=layer.moe_ep_rank,
            tp_size=layer.moe_tp_size,
            tp_rank=layer.moe_tp_rank,
            tune_max_num_tokens=next_power_of_2(x.shape[0]),
            activation_type=fi_activation,
            enable_alltoall=get_moe_a2a_backend().is_flashinfer(),
            **swiglu_kwargs,
        )
        # Only the first element of the kernel's result is used.
        return StandardCombineInput(hidden_states=kernel_outputs[0])

    # ------------------------------------------------------------------
    # Path 5: fallback to cutlass_moe_fp4
    # ------------------------------------------------------------------
    from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

    x = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output
    fp4_result = cutlass_moe_fp4(
        a=x,
        a1_gscale=layer.w13_input_scale_quant,
        w1_fp4=layer.w13_weight,
        w1_blockscale=layer.w13_blockscale_swizzled,
        w1_alphas=layer.g1_alphas,
        a2_gscale=layer.w2_input_scale_quant,
        w2_fp4=layer.w2_weight,
        w2_blockscale=layer.w2_blockscale_swizzled,
        w2_alphas=layer.g2_alphas,
        topk_weights=topk_output.topk_weights,
        topk_ids=topk_output.topk_ids,
        params=layer.cutlass_moe_params,
        apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
        no_combine=moe_runner_config.no_combine,
    )
    # The result is cast back to the dtype of the incoming hidden states.
    return StandardCombineInput(hidden_states=fp4_result.to(x.dtype))
|
|