MiniMax-M3-NVFP4 / sglang_patch /modelopt_quant.py
Mapika's picture
Upload folder using huggingface_hub
6684358 verified
Raw
History Blame Contribute Delete
103 kB
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
from __future__ import annotations
import logging
from enum import IntEnum
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import regex as re
import torch
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.environ import envs
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import (
MoeRunner,
MoeRunnerBackend,
MoeRunnerConfig,
get_moe_a2a_backend,
get_moe_runner_backend,
)
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.utils import (
is_flashinfer_cutedsl_v1_path,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp4_utils import (
fp4_quantize,
get_fp4_gemm_runner_backend,
)
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
is_blackwell_supported,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
prepare_moe_nvfp4_layer_for_marlin,
prepare_nvfp4_layer_for_marlin,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import (
convert_to_channelwise,
is_layer_skipped,
per_tensor_dequantize,
requantize_with_max_scale,
swizzle_blockscale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.utils import alias_or_bind_derived_param, copy_or_rebind_param
from sglang.srt.utils.common import (
get_device_capability,
is_cuda,
is_sm120_supported,
next_power_of_2,
round_up,
)
from sglang.srt.utils.custom_op import register_custom_op
from sglang.srt.utils.patch_torch import register_fake_if_exists
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
)
from sglang.srt.models.utils import WeightsMapper
# FlashInfer is optional. `enable_flashinfer_fp4_gemm` records whether its FP4
# GEMM entry point (`mm_fp4`) could be imported.
try:
    from flashinfer import mm_fp4 as flashinfer_fp4_gemm
    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a
    enable_flashinfer_fp4_gemm = True
except ImportError:
    enable_flashinfer_fp4_gemm = False
    reorder_rows_for_gated_act_gemm = None
    # NOTE(review): `shuffle_matrix_a` is bound only here, in the failure path;
    # the success path above never imports it. Conversely `flashinfer_fp4_gemm`
    # is left unbound on failure, so callers must check
    # `enable_flashinfer_fp4_gemm` first. Confirm both are intentional.
    shuffle_matrix_a = None
    shuffle_matrix_sf_a = None
if is_cuda():
try:
from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm as cutlass_fp4_gemm
except ImportError:
cutlass_fp4_gemm = None
else:
cutlass_fp4_gemm = None
try:
    from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
    from flashinfer.fused_moe.core import ActivationType
except ImportError:
    # Without flashinfer the fused-MoE entry point is disabled (None).
    flashinfer_cutlass_fused_moe = None
    # Define a minimal ActivationType enum if flashinfer is not available
    class ActivationType(IntEnum):
        # NOTE(review): the integer values presumably mirror flashinfer's own
        # ActivationType members - confirm whenever flashinfer is upgraded.
        Swiglu = 3
        Geglu = 4
        Relu2 = 6
        Identity = 7
# Initialize logger for the module
logger = logging.getLogger(__name__)
def _sglang_fp4_gemm_fake(
input: torch.Tensor,
weight: torch.Tensor,
input_sf: torch.Tensor,
weight_sf: torch.Tensor,
alpha: torch.Tensor,
out_dtype: torch.dtype,
out_features: int,
) -> torch.Tensor:
M = input.shape[-2]
N = int(out_features)
return input.new_empty((M, N), dtype=out_dtype)
@register_custom_op(fake_impl=_sglang_fp4_gemm_fake)
def fp4_gemm(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_sf: torch.Tensor,
    weight_sf: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    out_features: int,
) -> torch.Tensor:
    """Run an FP4 GEMM through the CUTLASS JIT kernel or FlashInfer.

    Backend selection, in order:
      1. CUTLASS, when the configured FP4 backend is CUTLASS and the JIT
         kernel was imported.
      2. FlashInfer ``mm_fp4``, when FlashInfer was imported.
      3. CUTLASS as a last resort, when the JIT kernel was imported.

    Args:
        input: Packed FP4 activations.
        weight: Packed FP4 weights.
        input_sf: Activation scale factors.
        weight_sf: Weight scale factors.
        alpha: Global scaling tensor forwarded to the kernel.
        out_dtype: dtype of the result.
        out_features: Not read by the real kernels; it is consumed by the
            fake implementation to size the output.

    Returns:
        The GEMM result produced by the selected backend.

    Raises:
        RuntimeError: If neither the CUTLASS JIT kernel nor FlashInfer is
            available. Previously this surfaced as an opaque
            "'NoneType' object is not callable" TypeError.
    """
    fp4_backend = get_fp4_gemm_runner_backend()
    cutlass_preferred = fp4_backend.is_cutlass() and cutlass_fp4_gemm is not None

    if not cutlass_preferred and enable_flashinfer_fp4_gemm:
        # Use the remapping logic to convert SGLang backend names to FlashInfer API names
        backend = fp4_backend.get_flashinfer_backend()
        return flashinfer_fp4_gemm(
            input, weight, input_sf, weight_sf, alpha, out_dtype, backend=backend
        )

    if cutlass_fp4_gemm is None:
        raise RuntimeError(
            "No FP4 GEMM backend is available: the CUTLASS JIT kernel "
            "(sglang.jit_kernel.nvfp4) and flashinfer.mm_fp4 both failed to import."
        )

    # flashinfer.fp4_quantize returns scale factors as uint8 (e4m3fn bits
    # stored in uint8 memory). The JIT kernel requires float8_e4m3fn dtype, so
    # reinterpret on every CUTLASS path (preferred and fallback alike).
    if input_sf.dtype != torch.float8_e4m3fn:
        input_sf = input_sf.view(torch.float8_e4m3fn)
    if weight_sf.dtype != torch.float8_e4m3fn:
        weight_sf = weight_sf.view(torch.float8_e4m3fn)
    return cutlass_fp4_gemm(input, weight, input_sf, weight_sf, alpha, out_dtype)
# Only attempted on CUDA, when SM120 is not reported as supported and
# `fp4_quantize` is not None.
# NOTE(review): `register_fake_if_exists` presumably skips registration when
# the named op is missing - confirm in sglang.srt.utils.patch_torch.
if is_cuda() and (not is_sm120_supported()) and (fp4_quantize is not None):
    @register_fake_if_exists("sgl_kernel::scaled_fp4_quant")
    def _sgl_kernel_scaled_fp4_quant_fake(
        output, input, output_scale, input_global_scale
    ):
        # Fake implementation: performs no work and returns None.
        return
# FP4 GEMM alignment constant - CUTLASS/FlashInfer kernels require dimensions divisible by 32
FP4_GEMM_ALIGNMENT = 32


def round_up_to_multiple(x: int, m: int) -> int:
    """Return the smallest multiple of ``m`` that is not below ``x``.

    Values already on a multiple of ``m`` come back unchanged.
    """
    # For integers, a - (a % m) == (a // m) * m, so this is the usual
    # "add m - 1, then floor to a multiple" rounding written without a division.
    bumped = x + m - 1
    return bumped - bumped % m
def pad_nvfp4_weight(
    weight: torch.Tensor,
    n_alignment: int = FP4_GEMM_ALIGNMENT,
    k_alignment: int = FP4_GEMM_ALIGNMENT,
) -> tuple[torch.Tensor, int]:
    """
    Zero-pad a packed NVFP4 weight so both GEMM dimensions meet kernel alignment.

    Alignment expectations differ per backend:
    - CUTLASS/cuDNN: N % 32 == 0, K % 32 == 0
    - TRTLLM: N % 128 == 0 (for shuffle_matrix_sf_a), K padding handled separately

    Args:
        weight: Packed FP4 weight tensor of shape [N, K//2] (2 FP4 values per byte)
        n_alignment: Required multiple for N (default 32, use 128 for TRTLLM);
            a value <= 0 disables row padding.
        k_alignment: Required multiple for K, counted in FP4 elements
            (default 32); a value <= 0 disables column padding.

    Returns:
        Tuple of (padded_weight, weights_padding_cols) where weights_padding_cols
        is the number of byte columns appended along K. When no padding is
        needed the input tensor itself is returned.
    """
    n_rows = weight.shape[0]  # N
    k_bytes = weight.shape[1]  # K // 2, packed

    # Rows still missing to reach the next multiple of n_alignment.
    extra_rows = 0
    if n_alignment > 0:
        if n_rows % n_alignment != 0:
            extra_rows = round_up_to_multiple(n_rows, n_alignment) - n_rows

    # K alignment is expressed in FP4 elements; each byte carries two of them.
    k_elements = 2 * k_bytes
    extra_col_bytes = 0
    if k_alignment > 0:
        if k_elements % k_alignment != 0:
            missing_elements = (
                round_up_to_multiple(k_elements, k_alignment) - k_elements
            )
            # Convert the element deficit back into bytes.
            extra_col_bytes = missing_elements // 2

    if extra_rows <= 0 and extra_col_bytes <= 0:
        return weight, extra_col_bytes

    # F.pad order for the last two dims: (left, right, top, bottom).
    padded = torch.nn.functional.pad(
        weight, (0, extra_col_bytes, 0, extra_rows)
    ).contiguous()
    return padded, extra_col_bytes
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_cols: int,
) -> torch.Tensor:
    """
    Extend packed FP4 activations along the last dim to mirror weight K-padding.

    Args:
        x_fp4: Packed FP4 activation tensor
        weights_padding_cols: Number of byte columns that were appended to the
            weight along K

    Returns:
        ``x_fp4`` itself when no padding is requested, otherwise a contiguous
        copy with ``weights_padding_cols`` zero columns appended on the right.
    """
    if weights_padding_cols <= 0:
        return x_fp4
    widened = torch.nn.functional.pad(x_fp4, (0, weights_padding_cols))
    return widened.contiguous()
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """
    Drop trailing N-dimension columns from a GEMM result if its width differs.

    Args:
        out: Output tensor from FP4 GEMM
        output_size: Original output size before padding

    Returns:
        ``out`` unchanged when its last dimension already equals
        ``output_size``; otherwise a contiguous copy of the first
        ``output_size`` columns.
    """
    if out.shape[-1] == output_size:
        return out
    trimmed = out[..., :output_size]
    return trimmed.contiguous()
# TODO make it true by default when the DeepEP PR is merged
# Evaluated once, at import time, from the SGLANG_MOE_NVFP4_DISPATCH setting.
MOE_NVFP4_DISPATCH = envs.SGLANG_MOE_NVFP4_DISPATCH.get()
# Supported activation schemes for the current configuration
ACTIVATION_SCHEMES = ["static"]
# Activation-function names.
# NOTE(review): the consumers of this tuple are outside this excerpt - confirm usage.
_SUPPORTED_ACT_STRS = ("silu", "relu2", "gelu")
class ModelOptQuantConfig(QuantizationConfig):
    """Shared base for the ModelOpt quantization configs in this module.

    Stores the KV-cache quantization algorithm, the list of excluded module
    patterns and the packed-module mapping, and implements the common
    layer -> quantization-method dispatch plus exclusion matching.
    """

    def __init__(
        self,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: Optional[List[str]],
        packed_modules_mapping: Optional[Dict[str, List[str]]],
    ):
        """
        Args:
            kv_cache_quant_algo: KV-cache quantization algorithm name, or None.
            exclude_modules: Module-name patterns to leave unquantized; None is
                stored as an empty list.
            packed_modules_mapping: Fused-module name -> shard names, or None.
        """
        super().__init__()
        self.packed_modules_mapping = packed_modules_mapping
        self.exclude_modules = exclude_modules or []
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.use_per_token_activation = False

    def _get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
        *,
        Linear: type[LinearMethodBase],
        Moe: type[FusedMoEMethodBase],
    ) -> Optional[QuantizeMethodBase]:
        """Pick the quantization method for ``layer``.

        Args:
            layer: The module being configured.
            prefix: The module's dotted name.
            Linear: Method class instantiated for quantized linear layers.
            Moe: Method class instantiated for quantized fused-MoE layers.

        Returns:
            ``UnquantizedLinearMethod`` for skipped/excluded linear layers,
            ``Linear(self)`` / ``Moe(self)`` for quantized ones,
            ``ModelOptFp8KVCacheMethod`` for attention layers when a KV-cache
            algorithm is configured, and None otherwise (including excluded
            MoE layers).
        """
        # Imported lazily, as in the rest of this module's dispatch code.
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(
                prefix, self.exclude_modules, self.packed_modules_mapping
            ) or self.is_layer_excluded(prefix):
                return UnquantizedLinearMethod()
            return Linear(self)
        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self)
        elif isinstance(layer, FusedMoE):
            # Check if MoE layer should be excluded from quantization
            # (e.g., MTP layers that have no quantization scales in checkpoint)
            if self.is_layer_excluded(prefix):
                # Falls back to default unquantized MoE
                return None
            return Moe(self)
        return None

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names this config can be loaded from."""
        return ["hf_quant_config.json"]

    def get_scaled_act_names(self) -> List[str]:
        """Return the (empty) list of scaled-activation names."""
        return []

    def apply_weight_name_mapper(
        self, hf_to_sglang_mapper: WeightsMapper
    ):  # noqa: B027
        """Translate ``exclude_modules`` from HF naming to sglang naming.

        Each mapped name is kept, and names starting with ``language_model.``
        are additionally added with that prefix removed. Order is preserved
        and duplicates are dropped.
        """
        # Map excluded module patterns from HF layout to sglang layout.
        # Ref: HF hf_quant_config.json for nvidia/Kimi-K2.5-NVFP4
        # https://huggingface.co/nvidia/Kimi-K2.5-NVFP4/blob/main/hf_quant_config.json
        if self.exclude_modules:
            mapped = hf_to_sglang_mapper.apply_list(self.exclude_modules)
            expanded: List[str] = []
            for name in mapped:
                expanded.append(name)
                if name.startswith("language_model."):
                    expanded.append(name.removeprefix("language_model."))
            # Preserve order, drop duplicates.
            self.exclude_modules = list(dict.fromkeys(expanded))

    def is_layer_excluded(self, prefix: str) -> bool:
        """Check if a layer should be excluded from quantization.

        Handles:
        - Exact matches (e.g., "lm_head" matching prefix "lm_head")
        - Glob-style wildcards (e.g., "mtp*" matching "mtp_layers")
        - Part-by-part matching (split prefix on "." and check each part)
        - language_model. prefix stripping for vision-language models
        - Fused module patterns (e.g., "q_a_proj" in "fused_qkv_a_proj_with_mqa")

        Only ``*`` acts as a wildcard; every other character of an exclude
        pattern (including regex metacharacters such as ``+``, ``(`` or ``[``)
        is matched literally.

        Args:
            prefix: Dotted module name to test.

        Returns:
            True if any exclude pattern matches, False otherwise (including
            when no exclude patterns are configured).
        """
        if not self.exclude_modules:
            return False

        # Build prefix variants: some models wrap layers under "language_model."
        prefixes_to_check = [prefix]
        if prefix.startswith("language_model."):
            prefixes_to_check.append(prefix.removeprefix("language_model."))

        # Fused module patterns: the exclude list may reference a sub-component
        # (e.g., "q_a_proj") that is fused into a combined parameter name
        # (e.g., "fused_qkv_a_proj_with_mqa"). We check if the last segment of
        # the exclude pattern is a substring of the last segment of the prefix.
        fused_patterns = {"q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"}

        for pattern in self.exclude_modules:
            # Escape the whole pattern so it is matched literally, then turn
            # the escaped glob star back into ".*" (e.g., "mtp*" -> "mtp.*").
            # Compiled once per pattern rather than once per comparison.
            matcher = re.compile(re.escape(pattern).replace(r"\*", ".*"))
            for pfx in prefixes_to_check:
                if matcher.fullmatch(pfx):
                    return True
                # Part-by-part check: handles wildcards like "mtp*" matching
                if any(matcher.fullmatch(part) for part in pfx.split(".")):
                    return True

            # Check fused patterns: if the last segment of the exclude pattern
            # is a known fused component, check if it appears in the prefix's
            # last segment (handles fused_qkv_a_proj_with_mqa containing q_a_proj)
            pattern_tail = pattern.rsplit(".", maxsplit=1)[-1]
            if pattern_tail in fused_patterns:
                for pfx in prefixes_to_check:
                    if pattern_tail in pfx.rsplit(".", maxsplit=1)[-1]:
                        return True
        return False
class ModelOptFp8Config(ModelOptQuantConfig):
    """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        kv_cache_quant_method: Optional[str] = None,
        exclude_modules: Optional[List[str]] = None,
        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
            kv_cache_quant_method: KV-cache quantization algorithm name, or None.
            exclude_modules: Module-name patterns to leave unquantized.
            packed_modules_mapping: Fused-module name -> shard names.
        """
        super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
            )

    @classmethod
    def override_quantization_method(cls, hf_quant_config, user_quant):
        """Override quantization method based on the model's config."""
        return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "modelopt_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum CUDA compute capability, encoded as major*10+minor."""
        return 89  # Compute capability 8.9 (Ada Lovelace); Hopper is 9.0.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
        """Build the config from either supported on-disk layout.

        Args:
            config: Parsed quantization config in flat or nested form.

        Returns:
            A config with ``is_checkpoint_fp8_serialized=True``.

        Raises:
            ValueError: If ``quant_algo`` cannot be found, or does not contain
                "FP8".
        """
        # Handle two different config formats:
        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}}
        # 2. config.json quantization_config format: {"quant_algo": "FP8", ...}
        # In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
        # For legacy reasons, we keep hf_quant_config.json for now.
        kv_cache_quant_method = None
        exclude_modules = None

        # Try flat format first (config.json quantization_config - preferred format)
        quant_method = config.get("quant_algo")
        if quant_method is not None:
            # Derive kv_cache quant from kv_cache_scheme dict
            kv_cache_scheme = config.get("kv_cache_scheme")
            if (
                isinstance(kv_cache_scheme, dict)
                and kv_cache_scheme.get("type") == "float"
                and kv_cache_scheme.get("num_bits") == 8
            ):
                kv_cache_quant_method = "FP8"
            # Map 'ignore' field to 'exclude_modules'
            exclude_modules = config.get("ignore")
        else:
            # Fall back to nested format (hf_quant_config.json - will be deprecated)
            try:
                # Only this lookup is expected to raise ValueError.
                quantization_section = cls.get_from_keys(config, ["quantization"])
            except ValueError as err:
                # Chain the original error so the missing-key detail is kept.
                raise ValueError(
                    "Cannot find 'quant_algo' in the model's quantization config. "
                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
                ) from err
            quant_method = quantization_section.get("quant_algo")
            kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo")
            exclude_modules = quantization_section.get("exclude_modules")

        if quant_method is None:
            raise ValueError(
                "Cannot find 'quant_algo' in the model's quantization config. "
            )
        if "FP8" not in quant_method:
            raise ValueError(
                "ModelOptFp8Config only supports static FP8 quantization in SGLang. "
                "For FP4 quantization, use ModelOptFp4Config. "
                "Check the quantization config for your model's configuration."
            )

        return cls(
            is_checkpoint_fp8_serialized=True,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            packed_modules_mapping=config.get("packed_modules_mapping"),
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Dispatch ``layer`` to the FP8 linear / MoE method classes."""
        return self._get_quant_method(
            layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
        )
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt static FP8 quantization.

    Loads FP8 checkpoints that carry static weight and activation scales.
    Dynamic scales may be supported in the future.

    **Limitations**:
    1. Only supports per-tensor quantization due to `torch._scaled_mm` limitations.
    2. Only supports the `float8_e4m3fn` data type.

    Args:
        quant_config (ModelOptFp8Config): The ModelOpt quantization configuration.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        super().__init__()
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: Optional[int],
        output_size: Optional[int],
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the weight and, for serialized FP8 checkpoints, its scales.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)``. ``weight_scale`` and ``input_scale`` hold
        one float32 entry per output partition, pre-filled with the float32
        minimum.
        """
        serialized = self.quant_config.is_checkpoint_fp8_serialized
        loader = extra_weight_attrs.get("weight_loader")
        fused_out_features = sum(output_partition_sizes)
        storage_dtype = torch.float8_e4m3fn if serialized else params_dtype

        # Bookkeeping consumed later by process_weights_after_loading.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = fused_out_features

        weight_storage = torch.empty(
            fused_out_features,
            input_size_per_partition,
            dtype=storage_dtype,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        if not serialized:
            return

        # One scale slot per logical (pre-fusion) output partition.
        scale_slots = (len(output_partition_sizes),)
        placeholder = torch.finfo(torch.float32).min
        for scale_name in ("weight_scale", "input_scale"):
            scale_param = PerTensorScaleParameter(
                data=torch.full(scale_slots, placeholder, dtype=torch.float32),
                weight_loader=loader,
            )
            layer.register_parameter(scale_name, scale_param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Fold per-partition scales into one max scale and requantize the weight."""
        unified_scale, requantized = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths
        )
        # The weight is stored transposed from here on.
        layer.weight = Parameter(requantized.t(), requires_grad=False)

        # cutlass sgl-kernel only supports per-channel scale
        if self.cutlass_fp8_supported:
            unified_scale = convert_to_channelwise(unified_scale, layer.logical_widths)
        layer.weight_scale = Parameter(unified_scale, requires_grad=False)

        # A single activation scale: the max across partitions.
        activation_scale = layer.input_scale.max()
        layer.input_scale = Parameter(activation_scale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 linear layer on ``x`` via ``apply_fp8_linear``."""
        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    KV-cache method for ModelOpt checkpoints: loads the FP8 kv-cache scaling
    factors. All behavior comes from ``BaseKVCacheMethod``.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        # No extra state; simply forward the config to the base class.
        super().__init__(quant_config)
class ModelOptMixedPrecisionConfig(ModelOptQuantConfig):
    """Configuration for ModelOpt MIXED_PRECISION checkpoints.

    Each layer is routed to the FP8 or NVFP4 method according to the
    checkpoint's ``quantized_layers`` map.
    """

    def __init__(
        self,
        kv_cache_quant_algo: Optional[str],
        exclude_modules: Optional[List[str]],
        packed_modules_mapping: Optional[Dict[str, List[str]]],
        quantized_layers: Dict[str, Dict[str, Any]],
        fp8_config: ModelOptFp8Config,
        nvfp4_config: ModelOptFp4Config,
    ) -> None:
        """
        Args:
            kv_cache_quant_algo: KV-cache quantization algorithm name, or None.
            exclude_modules: Module-name patterns to leave unquantized.
            packed_modules_mapping: Fused-module name -> shard names.
            quantized_layers: Layer name -> per-layer info (with "quant_algo").
            fp8_config: Config handed to the FP8 methods.
            nvfp4_config: Config handed to the NVFP4 methods.
        """
        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
        self.quantized_layers = quantized_layers
        self.fp8_config = fp8_config
        self.nvfp4_config = nvfp4_config

    @classmethod
    def override_quantization_method(cls, hf_quant_config, user_quant):
        """Return "modelopt_mixed" when the HF config declares it, else None."""
        if hf_quant_config is None:
            return None
        declared = hf_quant_config.get("quant_method", "")
        return "modelopt_mixed" if declared == "modelopt_mixed" else None

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "modelopt_mixed"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the same minimum capability as ``ModelOptFp4Config``."""
        return ModelOptFp4Config.get_min_capability()

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ModelOptMixedPrecisionConfig:
        """Build the config from a flat or nested quantization config.

        Raises:
            ValueError: If the algorithm is not "MIXED_PRECISION" or the
                ``quantized_layers`` map is empty.
        """
        quant_algo = config.get("quant_algo")
        if quant_algo is None:
            # Nested layout: everything lives under "quantization".
            section = cls.get_from_keys(config, ["quantization"])
            quant_algo = section.get("quant_algo")
            kv_cache_quant_algo = section.get("kv_cache_quant_algo")
            exclude_modules = section.get("exclude_modules")
            quantized_layers = section.get("quantized_layers", {})
        else:
            # Flat layout: derive the KV-cache algorithm from kv_cache_scheme.
            kv_cache_quant_algo = None
            kv_cache_scheme = config.get("kv_cache_scheme")
            if isinstance(kv_cache_scheme, dict):
                scheme_type = kv_cache_scheme.get("type")
                scheme_bits = kv_cache_scheme.get("num_bits")
                if scheme_type == "float" and scheme_bits == 8:
                    kv_cache_quant_algo = "FP8"
                elif scheme_type == "float" and scheme_bits == 4:
                    kv_cache_quant_algo = "NVFP4"
                else:
                    kv_cache_quant_algo = "auto"
            exclude_modules = config.get("ignore")
            quantized_layers = config.get("quantized_layers", {})

        if quant_algo != "MIXED_PRECISION":
            raise ValueError(
                "ModelOptMixedPrecisionConfig only supports MIXED_PRECISION checkpoints."
            )
        if not quantized_layers:
            raise ValueError(
                "MIXED_PRECISION quantization requires a non-empty quantized_layers map."
            )

        # The group size is taken from the first NVFP4 layer; 16 otherwise.
        first_nvfp4 = next(
            (
                info
                for info in quantized_layers.values()
                if info.get("quant_algo", "").upper() == "NVFP4"
            ),
            None,
        )
        group_size = None
        if first_nvfp4 is not None:
            group_size = first_nvfp4.get("group_size", 16)
        if group_size is None:
            group_size = 16

        packed_modules_mapping = config.get("packed_modules_mapping")
        # The sub-configs carry no exclusions of their own; exclusion is
        # decided by this config before either method is instantiated.
        fp8_config = ModelOptFp8Config(
            is_checkpoint_fp8_serialized=True,
            kv_cache_quant_method=kv_cache_quant_algo,
            exclude_modules=[],
            packed_modules_mapping=packed_modules_mapping,
        )
        nvfp4_config = ModelOptFp4Config(
            is_checkpoint_nvfp4_serialized=True,
            kv_cache_quant_algo=kv_cache_quant_algo,
            exclude_modules=[],
            packed_modules_mapping=packed_modules_mapping,
            group_size=group_size,
        )
        return cls(
            kv_cache_quant_algo=kv_cache_quant_algo,
            exclude_modules=exclude_modules,
            packed_modules_mapping=packed_modules_mapping,
            quantized_layers=quantized_layers,
            fp8_config=fp8_config,
            nvfp4_config=nvfp4_config,
        )

    def apply_weight_name_mapper(self, hf_to_sglang_mapper: WeightsMapper):
        """Remap the exclusion list and the ``quantized_layers`` keys."""
        super().apply_weight_name_mapper(hf_to_sglang_mapper)
        if not self.quantized_layers:
            return
        self.quantized_layers = hf_to_sglang_mapper.apply_dict(self.quantized_layers)

    def _resolve_quant_algo(self, prefix: str) -> Optional[str]:
        """Look up the upper-cased quant algorithm for the layer at ``prefix``.

        Resolution order: exact key; shards of a packed (fused) module; any
        key nested below ``prefix``. Returns None if nothing matches.

        Raises:
            ValueError: If shards of one fused layer disagree on the algorithm.
        """
        layer_table = self.quantized_layers
        if prefix in layer_table:
            return layer_table[prefix]["quant_algo"].upper()

        # Split off the last dotted component. When there is no dot, both
        # pieces are the prefix itself.
        pieces = prefix.rsplit(".", 1)
        parent, leaf = pieces[0], pieces[-1]
        packed = self.packed_modules_mapping
        if packed and leaf in packed:
            shard_keys = [f"{parent}.{shard_name}" for shard_name in packed[leaf]]
            algos = {
                layer_table[key]["quant_algo"].upper()
                for key in shard_keys
                if key in layer_table
            }
            if len(algos) == 1:
                return algos.pop()
            if len(algos) > 1:
                raise ValueError(
                    f"Mixed quant_algo within fused layer {prefix}: {algos}. "
                    "All shards must use the same quantization."
                )

        # Fall back to the first entry nested under this prefix.
        nested_marker = prefix + "."
        for key, info in layer_table.items():
            if key.startswith(nested_marker):
                return info["quant_algo"].upper()
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Pick the FP8 or NVFP4 method for ``layer`` from its resolved algorithm."""
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        quant_algo = self._resolve_quant_algo(prefix)

        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(
                prefix, self.exclude_modules, self.packed_modules_mapping
            ) or self.is_layer_excluded(prefix)
            if skipped:
                return UnquantizedLinearMethod()
            if quant_algo == "FP8":
                return ModelOptFp8LinearMethod(self.fp8_config)
            if quant_algo == "NVFP4":
                return ModelOptFp4LinearMethod(self.nvfp4_config)
            return UnquantizedLinearMethod()

        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self.fp8_config)

        if not isinstance(layer, FusedMoE):
            return None
        if self.is_layer_excluded(prefix):
            return None
        if quant_algo == "FP8":
            return ModelOptFp8MoEMethod(self.fp8_config)
        if quant_algo == "NVFP4":
            return ModelOptNvFp4FusedMoEMethod(self.nvfp4_config)
        return None
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
"""MoE method for ModelOpt FP8.
Supports loading FP8 checkpoints with static weight scale and activation scale.
Args:
quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config):
    """Store the config and cache whether CUTLASS FP8 kernels are supported."""
    self.quant_config = quant_config
    cutlass_ok = cutlass_fp8_supported()
    self.cutlass_fp8_supported = cutlass_ok
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the expert weights and, for serialized FP8, their scales.

    Registered in this order: ``w13_weight``, ``w2_weight``, then (serialized
    checkpoints only) ``w13_weight_scale``, ``w2_weight_scale``,
    ``w13_input_scale`` and ``w2_input_scale``.

    Args:
        layer: Module receiving the parameters.
        num_experts: Leading dimension of every parameter.
        hidden_size: Model hidden size.
        intermediate_size_per_partition: Per-shard intermediate size; w13 holds
            two such shards when the runner config is gated, otherwise one.
        params_dtype: Weight dtype used when the checkpoint is not FP8-serialized.
        **extra_weight_attrs: Must provide ``weight_loader`` if one is needed.
    """
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

    serialized = self.quant_config.is_checkpoint_fp8_serialized
    # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
    storage_dtype = torch.float8_e4m3fn if serialized else params_dtype
    loader = extra_weight_attrs.get("weight_loader")
    shard_count = 2 if layer.moe_runner_config.is_gated else 1
    w13_rows = shard_count * intermediate_size_per_partition

    layer.register_parameter(
        "w13_weight",
        ModelWeightParameter(
            data=torch.empty(num_experts, w13_rows, hidden_size, dtype=storage_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=loader,
        ),
    )
    layer.register_parameter(
        "w2_weight",
        ModelWeightParameter(
            data=torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=storage_dtype,
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=loader,
        ),
    )

    if not serialized:
        return

    # WEIGHT SCALES - Per-tensor scaling for ModelOpts
    # w13 gets one scale per shard (w1 and w3); they are merged into a single
    # per-expert scale after weight loading.
    unset = torch.finfo(torch.float32).min
    w13_weight_scale = PerTensorScaleParameter(
        data=torch.full((num_experts, shard_count), unset, dtype=torch.float32),
        weight_loader=loader,
    )
    w2_weight_scale = PerTensorScaleParameter(
        data=torch.full((num_experts,), unset, dtype=torch.float32),
        weight_loader=loader,
    )
    layer.register_parameter("w13_weight_scale", w13_weight_scale)
    layer.register_parameter("w2_weight_scale", w2_weight_scale)

    # Set weight loader attributes for scales
    # NOTE(review): nothing in this method reads extra_weight_attrs after this
    # point - confirm whether the attribute was meant to be attached to the
    # scale parameters.
    extra_weight_attrs["quant_method"] = FusedMoeWeightScaleSupported.TENSOR.value

    # INPUT SCALES - Per-tensor scaling for ModelOpt
    w13_input_scale = PerTensorScaleParameter(
        data=torch.full((num_experts,), 1.0, dtype=torch.float32),
        weight_loader=loader,
    )
    w2_input_scale = PerTensorScaleParameter(
        data=torch.full((num_experts,), 1.0, dtype=torch.float32),
        weight_loader=loader,
    )
    layer.register_parameter("w13_input_scale", w13_input_scale)
    layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Process FP8 MoE weights after loading from serialized checkpoint.
    Only supports pre-quantized checkpoints with FP8 weights and scales.

    Steps, in order:
    1. Re-wrap ``w13_weight`` / ``w2_weight`` as plain non-trainable Parameters.
    2. If ``w13_weight_scale`` is 2-D, requantize each w13 shard in place with
       the per-expert max of its shard scales and store that max as the scale.
    3. Collapse ``w13_input_scale`` / ``w2_input_scale`` to their global max.
    4. Backend-specific layout work for flashinfer TRT-LLM or flashinfer
       CUTLASS (derived scale tensors plus zero-padding of the intermediate
       dimension to a multiple of 16).

    Args:
        layer: The MoE layer whose parameters are rewritten in place.
    """
    layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
    layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
    # Handle scale parameters
    if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max of the w1 and w3 scales then dequant and requant each expert.
        if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
            # Get the maximum scale across w1 and w3 for each expert
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            # Requantize each expert's weights using the combined scale
            # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
            # where the first intermediate_size_per_partition rows are w1, the next are w3
            num_shards = 2 if layer.moe_runner_config.is_gated else 1
            intermediate_size_per_partition = (
                layer.w13_weight.shape[1] // num_shards
            )
            for expert_id in range(layer.w13_weight.shape[0]):
                # Row offset of the current shard inside this expert's w13 block.
                start = 0
                for shard_id in range(num_shards):  # (w1 and w3) or w13
                    # Dequantize using the original scale for this shard
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][
                            start : start + intermediate_size_per_partition, :
                        ],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    # Requantize using the combined max scale
                    # The quantized result is written back into the same rows;
                    # the scale returned by scaled_fp8_quant is discarded.
                    (
                        layer.w13_weight[expert_id][
                            start : start + intermediate_size_per_partition, :
                        ],
                        _,
                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    start += intermediate_size_per_partition
            # Update the scale parameter to be per-expert instead of per-shard
            layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
        else:
            # Not 2-D: keep the loaded values, only re-wrap as a plain Parameter.
            layer.w13_weight_scale = Parameter(
                layer.w13_weight_scale.data, requires_grad=False
            )
    if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data, requires_grad=False
        )
    # Input scales are reduced to a single value: the max over all entries.
    if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
        layer.w13_input_scale = Parameter(
            layer.w13_input_scale.max(), requires_grad=False
        )
    if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
        layer.w2_input_scale = Parameter(
            layer.w2_input_scale.max(), requires_grad=False
        )
    # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
    if get_moe_runner_backend().is_flashinfer_trtllm():
        from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
            align_fp8_moe_weights_for_flashinfer_trtllm,
        )
        # ModelOpt FP8 stores weights in [Up, Gate] order, so we need to swap
        align_fp8_moe_weights_for_flashinfer_trtllm(layer, swap_w13_halves=True)
    elif get_moe_runner_backend().is_flashinfer_cutlass():
        # This path needs all four scales; they exist only for serialized
        # FP8 checkpoints (see create_weights).
        assert (
            hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None
        )
        assert hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None
        assert (
            hasattr(layer, "w13_weight_scale")
            and layer.w13_weight_scale is not None
        )
        assert (
            hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None
        )
        input_scale = layer.w13_input_scale.to(torch.float32)
        activation_scale = layer.w2_input_scale.to(torch.float32)
        w13_weight_scale = layer.w13_weight_scale.to(torch.float32)
        w2_weight_scale = layer.w2_weight_scale.to(torch.float32)
        # Derived scale tensors attached to the layer.
        # NOTE(review): presumably consumed by the flashinfer CUTLASS apply
        # path, which is outside this excerpt - confirm.
        layer.fc1_dequant = Parameter(
            w13_weight_scale * input_scale, requires_grad=False
        )
        layer.fc2_quant = Parameter(
            activation_scale.reciprocal(), requires_grad=False
        )
        layer.fc2_dequant = Parameter(
            activation_scale * w2_weight_scale, requires_grad=False
        )
        layer.fc1_input_dequant = Parameter(input_scale, requires_grad=False)
        # flashinfer_cutlass kernel requires intermediate_size to be a
        # multiple of 16. Pad weight tensors with zeros after loading.
        # For gated activations (swiglu), w13 is [Up, Gate] concatenated
        # along dim 1 — we must split, pad each half separately, and
        # re-concat so the kernel's half-split stays aligned.
        num_shards = 2 if layer.moe_runner_config.is_gated else 1
        # Per-shard intermediate size of w13.
        isp = layer.w13_weight.shape[1] // num_shards
        if isp % 16 != 0:
            pad_amount = round_up(isp, 16) - isp
            w13_data = layer.w13_weight.data
            if num_shards == 2:
                up_weight = w13_data[:, :isp, :]
                gate_weight = w13_data[:, isp:, :]
                # (0, 0, 0, pad_amount) appends rows at the end of dim 1 only.
                layer.w13_weight = Parameter(
                    torch.cat(
                        [
                            torch.nn.functional.pad(
                                up_weight, (0, 0, 0, pad_amount)
                            ),
                            torch.nn.functional.pad(
                                gate_weight, (0, 0, 0, pad_amount)
                            ),
                        ],
                        dim=1,
                    ),
                    requires_grad=False,
                )
            else:
                layer.w13_weight = Parameter(
                    torch.nn.functional.pad(w13_data, (0, 0, 0, pad_amount)),
                    requires_grad=False,
                )
            # w2 carries the intermediate size on its last dim, so it is
            # padded there by the same amount.
            layer.w2_weight = Parameter(
                torch.nn.functional.pad(layer.w2_weight.data, (0, pad_amount)),
                requires_grad=False,
            )
def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Remember the runner config and build a Triton-backed ``MoeRunner``.

    ``layer`` is accepted for interface compatibility and is not used here.
    """
    self.moe_runner_config = moe_runner_config
    triton_backend = MoeRunnerBackend.TRITON
    self.runner = MoeRunner(triton_backend, moe_runner_config)
def apply(
    self,
    layer: torch.nn.Module,
    dispatch_output: StandardDispatchOutput,
) -> CombineInput:
    """Run the FP8 MoE forward pass on the backend selected at runtime.

    Three paths are tried in order:
      1. FlashInfer TRT-LLM, when that backend is selected and the top-k
         output is in the bypassed format.
      2. FlashInfer CUTLASS, when that backend is selected.
      3. ``self.runner`` with a ``TritonMoeQuantInfo`` otherwise.

    Args:
        layer: MoE layer holding the FP8 weights and scale tensors.
        dispatch_output: Dispatcher result providing ``hidden_states`` and
            ``topk_output``.

    Returns:
        The combine input produced by the selected path.

    Raises:
        AssertionError: If the configured activation is not supported by
            the selected FlashInfer path.
    """
    hidden_states = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output

    from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
    from sglang.srt.layers.moe.topk import TopKOutputChecker

    # Fast path: TRT-LLM FP8 per-tensor MoE using BYPASSED TopK routing
    trtllm_selected = get_moe_runner_backend().is_flashinfer_trtllm()
    if trtllm_selected and TopKOutputChecker.format_is_bypassed(topk_output):
        from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
            FlashInferTrtllmFp8MoeQuantInfo,
            fused_experts_none_to_flashinfer_trtllm_fp8,
            get_activation_type,
        )
        from sglang.srt.layers.moe.utils import RoutingMethodType

        # NOTE(review): read here but not used further down in this branch.
        topk_config = topk_output.topk_config

        runner_config = self.moe_runner_config
        _SUPPORTED_FP8_ACTIVATIONS = {"silu", "relu2"}
        assert runner_config.activation in _SUPPORTED_FP8_ACTIVATIONS, (
            f"Only {_SUPPORTED_FP8_ACTIVATIONS} are supported for "
            f"flashinfer trtllm fp8 moe, got '{runner_config.activation}'"
        )

        local_experts = layer.num_local_experts
        trtllm_quant_info = FlashInferTrtllmFp8MoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            global_num_experts=layer.num_experts,
            local_expert_offset=layer.moe_ep_rank * local_experts,
            local_num_experts=local_experts,
            intermediate_size=layer.w2_weight.shape[2],
            # Layers without an explicit routing method use Llama4 routing.
            routing_method_type=getattr(
                layer, "routing_method_type", RoutingMethodType.Llama4
            ),
            block_quant=False,
            w13_input_scale=layer.w13_input_scale,
            output1_scales_scalar=layer.output1_scales_scalar,
            output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
            output2_scales_scalar=layer.output2_scales_scalar,
            use_routing_scales_on_input=True,
            activation_type=get_activation_type(
                runner_config.activation,
                is_gated=runner_config.is_gated,
            ),
        )
        return fused_experts_none_to_flashinfer_trtllm_fp8(
            dispatch_output, trtllm_quant_info, runner_config
        )

    if get_moe_runner_backend().is_flashinfer_cutlass():
        from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
            get_activation_type,
        )

        runner_config = self.moe_runner_config
        activation_str = runner_config.activation
        assert activation_str in _SUPPORTED_ACT_STRS, (
            f"Activation {activation_str!r} is not supported for "
            f"flashinfer cutlass fp8 moe (supported: {_SUPPORTED_ACT_STRS})."
        )
        activation = ActivationType(
            get_activation_type(activation_str, is_gated=runner_config.is_gated)
        )
        # FlashInfer CUTLASS MoE supports gated Swiglu/Geglu and non-gated
        # Relu2/Identity. Non-gated Silu/Gelu are not implemented.
        _CUTLASS_SUPPORTED = {
            ActivationType.Swiglu,
            ActivationType.Geglu,
            ActivationType.Relu2,
            ActivationType.Identity,
        }
        assert activation in _CUTLASS_SUPPORTED, (
            f"Activation {activation_str!r} (is_gated="
            f"{runner_config.is_gated}) maps to {activation.name}, "
            "which is not supported by flashinfer cutlass fp8 moe."
        )

        # Quantize the activations with the static w13 input scale.
        quantized_input, _ = scaled_fp8_quant(hidden_states, layer.w13_input_scale)
        out_dtype = hidden_states.dtype
        num_tokens = hidden_states.shape[0]
        hidden_dim = hidden_states.shape[1]

        # Only the output buffer is allocated under the symmetric-memory
        # context.
        with use_symmetric_memory(
            get_tp_group(), disabled=not is_allocation_symmetric()
        ):
            out_buffer = torch.empty(
                num_tokens, hidden_dim, dtype=out_dtype, device=hidden_states.device
            )

        moe_output = flashinfer_cutlass_fused_moe(
            output=out_buffer,
            input=quantized_input,
            token_selected_experts=topk_output.topk_ids.to(torch.int),
            token_final_scales=topk_output.topk_weights,
            fc1_expert_weights=layer.w13_weight,
            fc2_expert_weights=layer.w2_weight,
            output_dtype=out_dtype,
            input_sf=None,
            quant_scales=[
                layer.fc1_dequant,
                layer.fc2_quant,
                layer.fc2_dequant,
                layer.fc1_input_dequant,
            ],
            ep_size=layer.moe_ep_size,
            ep_rank=layer.moe_ep_rank,
            tp_size=layer.moe_tp_size,
            tp_rank=layer.moe_tp_rank,
            tune_max_num_tokens=next_power_of_2(num_tokens),
            activation_type=activation,
        )[0]

        # Apply the routed scaling factor here unless top-k already fused it.
        if not layer.should_fuse_routed_scaling_factor_in_topk:
            routed_scale = runner_config.routed_scaling_factor
            if routed_scale is not None:
                moe_output.mul_(routed_scale)
        return StandardCombineInput(hidden_states=moe_output)

    # Default path: the runner built in create_moe_runner.
    return self.runner.run(
        dispatch_output,
        TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,
            per_channel_quant=False,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        ),
    )
class ModelOptFp4Config(ModelOptQuantConfig):
"""Config class for FP4."""
def __init__(
    self,
    is_checkpoint_nvfp4_serialized: bool = False,
    kv_cache_quant_algo: str = None,
    group_size: int = None,
    exclude_modules: List[str] = None,
    packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
    use_per_token_activation: Optional[bool] = None,
) -> None:
    """Initialise the FP4 config.

    Args:
        is_checkpoint_nvfp4_serialized: Whether the checkpoint already
            stores NVFP4 weights; a warning is logged when true.
        kv_cache_quant_algo: Forwarded to the base-class constructor.
        group_size: Quantization group size stored on the instance.
        exclude_modules: Forwarded to the base-class constructor.
        packed_modules_mapping: Forwarded to the base-class constructor.
        use_per_token_activation: When falsy, the value of the
            ``SGLANG_FLASHINFER_NVFP4_PER_TOKEN_ACTIVATION`` environment
            setting is used instead.
    """
    super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)

    self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
    if is_checkpoint_nvfp4_serialized:
        logger.warning(
            "Detected nvfp4 checkpoint. Please note that the "
            "format is experimental and subject to change."
        )
    self.group_size = group_size

    # Any falsy argument (None or False) defers to the environment setting.
    per_token = use_per_token_activation
    if not per_token:
        per_token = envs.SGLANG_FLASHINFER_NVFP4_PER_TOKEN_ACTIVATION.get()
    self.use_per_token_activation = per_token
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
    """Decide the quantization method override from the model's config.

    The decision is delegated to ``_modelopt_override_quantization_method``.
    """
    override = cls._modelopt_override_quantization_method(
        hf_quant_config, user_quant
    )
    return override
@classmethod
def get_name(cls) -> str:
return "modelopt_fp4"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
@classmethod
def get_min_capability(cls) -> int:
return 80
@staticmethod
def common_group_size(cfg: dict) -> int:
"""Return the unique group_size across the config; raise if missing/mismatched."""
sizes = set()
# Top-level and 'quantization' block
v = cfg.get("group_size")
if isinstance(v, int):
sizes.add(v)
q = cfg.get("quantization")
if isinstance(q, dict):
v = q.get("group_size")
if isinstance(v, int):
sizes.add(v)
# config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
for g in (cfg.get("config_groups") or {}).values():
if isinstance(g, dict):
v = g.get("group_size")
if isinstance(v, int):
sizes.add(v)
for sub in g.values():
if isinstance(sub, dict):
v = sub.get("group_size")
if isinstance(v, int):
sizes.add(v)
if not sizes:
raise ValueError("No group_size found in config.")
if len(sizes) > 1:
raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
return next(iter(sizes))
@classmethod
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
    """Build a config from either supported ModelOpt quantization layout.

    Two layouts are handled:
      1. Flat (config.json ``quantization_config``, preferred):
         ``{"quant_algo": "NVFP4", ...}``
      2. Nested (legacy hf_quant_config.json):
         ``{"quantization": {"quant_algo": "NVFP4", ...}}``

    In future modelopt will deprecate hf_quant_config.json and only keep
    config.json; the nested layout is kept for legacy reasons.

    Args:
        config: Parsed quantization config in one of the layouts above.

    Returns:
        A new instance of ``cls``.

    Raises:
        ValueError: If ``quant_algo`` cannot be found or is neither ``FP8``
            nor ``NVFP4``, if the nested layout has a missing or
            inconsistent ``group_size``, or if ``group_size`` /
            ``exclude_modules`` end up undefined.
    """
    kv_cache_quant_algo = None
    group_size = None
    exclude_modules = []

    # Try flat format first (config.json quantization_config - preferred format)
    quant_method = config.get("quant_algo")
    if quant_method is not None:
        # Derive kv_cache_quant_algo from kv_cache_scheme (dict or string).
        kv_cache_scheme = config.get("kv_cache_scheme")
        if isinstance(kv_cache_scheme, dict):
            if (
                kv_cache_scheme.get("type") == "float"
                and kv_cache_scheme.get("num_bits") == 8
            ):
                kv_cache_quant_algo = "FP8"
            else:
                kv_cache_quant_algo = "auto"
        elif isinstance(kv_cache_scheme, str):
            scheme_name = kv_cache_scheme.strip().upper()
            if scheme_name in ("FP8", "FLOAT8"):
                kv_cache_quant_algo = "FP8"
            elif scheme_name in ("FP4", "FLOAT4", "NVFP4"):
                kv_cache_quant_algo = "NVFP4"
            else:
                kv_cache_quant_algo = "auto"
        else:
            kv_cache_quant_algo = "auto"

        group_size = config.get("group_size")
        # If group_size is not at top level, try to extract from config_groups
        if group_size is None:
            config_groups = config.get("config_groups", {})
            if config_groups:
                # Get group_size from the first group's weights config;
                # an explicit null "weights" entry is treated as empty.
                first_group = next(iter(config_groups.values()), {})
                weights_config = first_group.get("weights") or {}
                group_size = weights_config.get("group_size")

        exclude_modules = config.get("ignore", [])
    else:
        # Fall back to nested format (hf_quant_config.json - legacy format).
        # Only the quant_algo lookup is guarded, so a group_size problem is
        # reported with its own message instead of being relabelled.
        try:
            quant_config = cls.get_from_keys(config, ["quantization"])
            quant_method = quant_config["quant_algo"]
        except (ValueError, KeyError) as err:
            raise ValueError(
                "Cannot find 'quant_algo' in the model's quantization config. "
                "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
            ) from err
        kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo") or "auto"
        group_size = ModelOptFp4Config.common_group_size(config)
        exclude_modules = quant_config.get("exclude_modules", [])

    if quant_method not in ["FP8", "NVFP4"]:
        raise ValueError(
            "ModelOpt currently only supports: FP8, NVFP4"
            " quantizations in sglang. Please check the "
            "quantization config for your model's configuration."
        )
    is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

    if group_size is None or exclude_modules is None:
        logger.warning(
            "group_size: %s,kv_cache_quant_algo: %s,exclude_modules: %s",
            group_size,
            kv_cache_quant_algo,
            exclude_modules,
        )
        raise ValueError(
            "NVFP4 quantization requires group_size and exclude_modules "
            "specified in the quantization config"
        )
    return cls(
        is_checkpoint_nvfp4_serialized,
        kv_cache_quant_algo,
        group_size,
        exclude_modules,
        config.get("packed_modules_mapping"),
    )
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
    """Select the quantization method for ``layer``.

    Delegates to ``_get_quant_method``, supplying the NVFP4 linear and
    fused-MoE method classes.
    """
    method_classes = {
        "Linear": ModelOptFp4LinearMethod,
        "Moe": ModelOptNvFp4FusedMoEMethod,
    }
    return self._get_quant_method(layer, prefix, **method_classes)
class ModelOptFp4LinearMethod(LinearMethodBase):
"""Linear method for NVFP4.
Supports loading NVFP4 checkpoints with the following structure:
|Tensor Name | datatype | shape |
|----------------------------------------------------|
|input_scale | torch.float32 | scalar |
|weight | NVFP4(SE2M1) | [1, X, y/2] |
|weight_scale | FP8-E4M3 | [X, Y] |
|weight_scale_2 | torch.float32 | scalar |
The weights are quantized per block of 16 elements.
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp4Config):
self.quant_config = quant_config
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: List[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Allocate and register the NVFP4 parameters of a linear layer.

    Registers, in order: ``weight`` (packed FP4 stored as uint8),
    ``input_scale`` and ``weight_scale_2`` (one float32 entry per logical
    output partition), and ``weight_scale`` (FP8-E4M3 block scales).

    Args:
        layer: Module that receives the parameters and bookkeeping attributes.
        input_size_per_partition: Input features on this partition; must be
            a multiple of 16.
        output_partition_sizes: Output width of each logical partition.
        input_size: Unused.
        output_size: Unused.
        params_dtype: Recorded on the layer as ``params_dtype``.
        **extra_weight_attrs: Only ``weight_loader`` is read.

    Raises:
        ValueError: If the checkpoint is not NVFP4-serialized, or if
            ``input_size_per_partition`` is not a multiple of 16.
    """
    del input_size, output_size
    if not self.quant_config.is_checkpoint_nvfp4_serialized:
        raise ValueError(
            "NVFP4 quantization was selected, "
            " dynamic quantization is not supported."
        )

    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")

    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.params_dtype = params_dtype
    layer.quant_config = self.quant_config

    if input_size_per_partition % 16 != 0:
        raise ValueError(
            "Unsupported model when in features size is not multiple of 16"
        )

    # Only serialized NVFP4 checkpoints reach this point (see the check
    # above), so the block scales are always stored as FP8-E4M3.
    block_scale_dtype = torch.float8_e4m3fn

    weight = ModelWeightParameter(
        data=torch.empty(
            # 2 fp4 data is packed in one uint8 in the input dimension
            output_size_per_partition,
            input_size_per_partition // 2,
            dtype=torch.uint8,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight", weight)

    input_scale = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("input_scale", input_scale)

    weight_scale_2 = PerTensorScaleParameter(
        data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight_scale_2", weight_scale_2)

    # One block scale per `group_size` input elements.
    weight_scale = ModelWeightParameter(
        data=torch.empty(
            output_size_per_partition,
            input_size_per_partition // self.quant_config.group_size,
            dtype=block_scale_dtype,
        ),
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Post-load hook for an NVFP4 linear layer.
#
# Derives the scalar `alpha` and `input_scale_inv` values, then rewrites
# `weight` / `weight_scale` for the selected FP4 GEMM backend: Marlin,
# FlashInfer TRTLLM, or the default padded + interleaved layout.
#
# Args:
#   layer: module whose `input_scale`, `weight`, `weight_scale` and
#     `weight_scale_2` are read below.
# Raises:
#   ValueError: Marlin backend with group_size != 16, or a non-Marlin
#     backend when is_blackwell_supported() is false.
#
# Collapse the per-partition scales to one scalar each via max().
input_scale_2 = layer.input_scale.max().to(torch.float32)
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
# alpha / input_scale_inv stay as scalar Parameters. Aliasing them into
# the [N_partitions] source slot breaks fused-QKV linears whose
# downstream kernels assume scalar input scale.
copy_or_rebind_param(
layer, "alpha", (input_scale_2 * weight_scale_2).to(torch.float32)
)
copy_or_rebind_param(
layer, "input_scale_inv", (1 / input_scale_2).to(torch.float32)
)
# Store original output size before any padding
layer.output_size_per_partition = layer.weight.shape[0]
# --- Marlin path: publish the global scales, repack, and stop. ---
if get_fp4_gemm_runner_backend().is_marlin():
if self.quant_config.group_size != 16:
raise ValueError(
f"NVFP4 Marlin requires group_size=16, got {self.quant_config.group_size}."
)
copy_or_rebind_param(layer, "input_global_scale", input_scale_2)
copy_or_rebind_param(layer, "weight_global_scale", weight_scale_2)
prepare_nvfp4_layer_for_marlin(layer)
layer.weights_padding_cols = 0
return
# Every remaining backend is gated on is_blackwell_supported().
if not is_blackwell_supported():
raise ValueError(
"ModelOpt NVFP4 native dense GEMM backends require SM100+. "
"Use --fp4-gemm-backend marlin on SM80-SM90."
)
# --- FlashInfer TRTLLM path ---
if get_fp4_gemm_runner_backend().is_flashinfer_trtllm():
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
# shuffles ourselves.
#
# Alignment requirements:
# - shuffle_matrix_a: weight.shape[0] (N) % 32 == 0
# - shuffle_matrix_sf_a: scale.shape[0] (N) % 128 == 0, scale.shape[1] (K/16) % 4 == 0
# We pad N to multiple of 128 and K/16 to multiple of 4.
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
# Pad weight N dimension to 128
weight, _ = pad_nvfp4_weight(
layer.weight.data, n_alignment=128, k_alignment=0
)
# Pad scale N dimension to match weight
scale = layer.weight_scale
if scale.shape[0] != weight.shape[0]:
pad_n = weight.shape[0] - scale.shape[0]
scale = torch.nn.functional.pad(scale, (0, 0, 0, pad_n))
# Pad K dimension: scale K/16 must be multiple of 4
scale_k = scale.shape[1] # K/16
weights_padding_cols = 0
if scale_k % 4 != 0:
padded_scale_k = round_up_to_multiple(scale_k, 4)
pad_scale_k = padded_scale_k - scale_k
# Pad scale K/16 dimension
scale = torch.nn.functional.pad(scale, (0, pad_scale_k, 0, 0))
# Pad weight K/2 dimension correspondingly (K/2 = K/16 * 8)
pad_weight_k = pad_scale_k * 8
weight = torch.nn.functional.pad(weight, (0, pad_weight_k, 0, 0))
# Store K padding for activation padding in apply()
weights_padding_cols = pad_weight_k
# Shuffle for TRTLLM layout
epilogue_tile_m = 128
# The shuffled scale is reshaped back to its padded 2-D shape below.
shuffled_scale_shape = scale.shape
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
scale = (
shuffle_matrix_sf_a(scale.view(torch.uint8), epilogue_tile_m)
.reshape(shuffled_scale_shape)
.view(torch.float8_e4m3fn)
)
alias_or_bind_derived_param(
layer, "weight_scale", "weight_scale_interleaved", scale
)
copy_or_rebind_param(layer, "weight", weight)
layer.weights_padding_cols = weights_padding_cols
return
# --- Default path ---
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N divisible by 32)
weight, weights_padding_cols = pad_nvfp4_weight(layer.weight.data)
layer.weights_padding_cols = weights_padding_cols
copy_or_rebind_param(layer, "weight", weight)
# Pad and blockwise interleave weight_scale
scales = layer.weight_scale
scale_ndim = scales.ndim
# Work on a 3-D (B, M, K) view; a 2-D scale gets a leading batch dim of 1.
if scale_ndim == 2:
scales = scales.unsqueeze(0)
assert scales.ndim == 3
B, M, K = scales.shape
M_padded = round_up_to_multiple(M, 128)
K_padded = round_up_to_multiple(K, 4)
# NOTE(review): no device is passed here, so this buffer is created on
# the default device and only moved by the .cuda() call further down.
padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
padded_scales[:B, :M, :K] = scales
# Snapshot the raw (pre-swizzle) scale BEFORE alias_or_bind_derived_param
# overwrites layer.weight_scale.data in-place via .copy_() on the broadcast
# path. Without this, the swiglu side-channel below would read the swizzled
# bytes when it later re-reads layer.weight_scale.
raw_scale_snapshot = (
(scales.squeeze(0) if scale_ndim == 2 else scales).detach().clone()
)
batches, rows, cols = padded_scales.shape
assert rows % 128 == 0
assert cols % 4 == 0
# Split rows into (rows // 128, 4, 32) and cols into (cols // 4, 4), then
# reorder the axes to (batch, row_tile, col_tile, 32, 4, 4).
padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
padded_scales = padded_scales.contiguous().cuda()
# Restore the original rank (2-D or 3-D) with the padded extents.
padded_scales = (
padded_scales.reshape(M_padded, K_padded)
if scale_ndim == 2
else padded_scales.reshape(B, M_padded, K_padded)
)
alias_or_bind_derived_param(
layer, "weight_scale", "weight_scale_interleaved", padded_scales
)
# Optional extra layout for layers flagged with _interleave_for_swiglu_fusion.
if getattr(layer, "_interleave_for_swiglu_fusion", False):
from sglang.srt.layers.quantization.nvfp4_gemm_swiglu_nvfp4_quant import (
interleave_linear_and_gate,
swizzle_blockscale_2d,
)
w = layer.weight.data
assert weights_padding_cols == 0, (
"_interleave_for_swiglu_fusion does not support K-padded weights; "
f"got weights_padding_cols={weights_padding_cols}."
)
assert raw_scale_snapshot.shape[0] == w.shape[0], (
"_interleave_for_swiglu_fusion requires no N-padding; "
f"raw_scale rows={raw_scale_snapshot.shape[0]} vs weight rows={w.shape[0]}."
)
assert w.shape[0] % 128 == 0, (
"_interleave_for_swiglu_fusion requires N % 128 == 0 (group_size=64 "
f"with gate+up halves); got N={w.shape[0]}."
)
# The first half of dim 0 is taken as gate and the second as up; they are
# concatenated as (up, gate) before interleaving.
gate_w, up_w = w.chunk(2, dim=0)
w_swiglu = interleave_linear_and_gate(
torch.cat((up_w, gate_w), dim=0), group_size=64, dim=0
)
# Same reordering for the raw (un-swizzled) block scales, then swizzle.
gate_s, up_s = raw_scale_snapshot.chunk(2, dim=0)
w_scale_swiglu = swizzle_blockscale_2d(
interleave_linear_and_gate(
torch.cat((up_s, gate_s), dim=0), group_size=64, dim=0
)
)
layer.weight_swiglu_interleaved = w_swiglu
layer.weight_scale_swiglu_interleaved = w_scale_swiglu
# Keep the Parameter objects alive so weight reload can refill
# them and re-run this hook; free their storage in the meantime.
layer.weight.data = torch.empty(
0, dtype=layer.weight.dtype, device=layer.weight.device
)
layer.weight_scale_interleaved.data = torch.empty(
0,
dtype=layer.weight_scale_interleaved.dtype,
device=layer.weight_scale_interleaved.device,
)
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute the NVFP4 linear output for ``x``.

    Args:
        layer: Linear layer prepared by ``process_weights_after_loading``.
        x: Activations, or an ``(fp4, scale)`` tuple when the layer opts in
            through ``_accepts_prequantized_fp4``.
        bias: Optional bias added to the result.

    Returns:
        Tensor of shape ``[rows, layer.output_size_per_partition]``.
    """
    if get_fp4_gemm_runner_backend().is_marlin():
        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_global_scale=layer.weight_global_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )

    # `_accepts_prequantized_fp4` is the explicit opt-in so an accidental
    # tuple from unrelated code can't silently bypass quantization.
    prequantized = getattr(layer, "_accepts_prequantized_fp4", False) and isinstance(
        x, tuple
    )
    if prequantized:
        act_fp4, act_scale = x
        num_rows = act_fp4.shape[0]
        result_dtype = layer.params_dtype
    else:
        act_fp4, act_scale = fp4_quantize(x, layer.input_scale_inv)
        num_rows, _ = x.shape
        result_dtype = x.dtype

    out_features = layer.output_size_per_partition
    padded_out_features, _ = layer.weight.shape

    assert act_fp4.dtype == torch.uint8
    assert layer.weight.dtype == torch.uint8
    assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
    assert layer.alpha.dtype == torch.float32

    # Pad activations to match weight K-dimension padding
    act_fp4 = pad_nvfp4_activation_for_cutlass(
        act_fp4, getattr(layer, "weights_padding_cols", 0)
    )

    # FlashInfer GEMM (any non-CUTLASS backend) takes transposed operands.
    transpose_operands = (
        enable_flashinfer_fp4_gemm
        and not get_fp4_gemm_runner_backend().is_cutlass()
    )
    if transpose_operands:
        gemm_weight = layer.weight.T
        gemm_weight_scale = layer.weight_scale_interleaved.T
    else:
        gemm_weight = layer.weight
        gemm_weight_scale = layer.weight_scale_interleaved

    result = fp4_gemm(
        act_fp4,
        gemm_weight,
        act_scale,
        gemm_weight_scale,
        layer.alpha,
        result_dtype,
        padded_out_features,
    )

    # Slice output to remove N-dimension padding
    result = slice_nvfp4_output(result, out_features)
    if bias is not None:
        result = result + bias
    return result.view(num_rows, out_features)
class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
"""
MoE Method for FP4 Quantization with Blockscales and PerTensorScales
Args:
quant_config: NVFP4 Quant Config
"""
def __init__(self, quant_config: ModelOptFp4Config):
    """Check platform support for NVFP4 MoE and cache backend flags.

    Args:
        quant_config: NVFP4 quantization config kept on the instance.

    Raises:
        ValueError: If ``is_blackwell_supported()`` is false and the Marlin
            fallback does not apply.
    """
    self.quant_config = quant_config

    backend = get_moe_runner_backend()
    if backend.is_auto() and is_cuda():
        # "auto" on CUDA: use the Marlin fallback for capabilities in
        # [(8, 0), (10, 0)).
        capability = get_device_capability()
        use_marlin_fallback = (8, 0) <= capability < (10, 0)
    else:
        use_marlin_fallback = backend.is_marlin()

    if not (is_blackwell_supported() or use_marlin_fallback):
        raise ValueError(
            "Current platform does not support NVFP4"
            " quantization with the selected MoE backend. Please use "
            "Blackwell and above, or use moe_runner_backend=marlin on SM80+."
        )

    self.enable_flashinfer_trtllm_moe = (
        backend.is_flashinfer_trtllm() or backend.is_flashinfer_trtllm_routed()
    )
    self._cache_permute_indices = {}
@property
def enable_flashinfer_cutlass_moe(self) -> bool:
    """Access the global enable_flashinfer_cutlass_moe setting."""
    # The docstring must be the first statement of the body to be picked up
    # as __doc__; the import therefore follows it.
    from sglang.srt.layers.moe import get_moe_runner_backend

    return get_moe_runner_backend().is_flashinfer_cutlass()
@property
def enable_flashinfer_cutedsl_moe(self) -> bool:
    """Whether the globally selected MoE runner backend is FlashInfer CuteDSL."""
    from sglang.srt.layers.moe import get_moe_runner_backend

    backend = get_moe_runner_backend()
    return backend.is_flashinfer_cutedsl()
# ----- CuteDSL v1 vs v2 path helpers -----
#
# "v1": cutedsl + deepep low-latency.
# - MoeRunner fused func calls flashinfer_cutedsl_moe_masked
# (grouped_gemm_nt_masked).
# - Expects W13 in default [Gate, Up] order, NOT interleaved.
# - Uses swizzled blockscales directly (w13_blockscale_swizzled).
#
# "v2" (standard): cutedsl + none/flashinfer a2a.
# - MoeRunner fused func calls CuteDslMoEWrapper kernels.
# - Expects W13 in [Up, Gate] order, interleaved in 64-row chunks.
# - Uses MMA-layout blockscales (w13_blockscale_mma).
@property
def _is_cutedsl_v1_deepep(self) -> bool:
    """Whether the CuteDSL v1 path (DeepEP low-latency, masked grouped GEMM) is active."""
    on_v1_path = is_flashinfer_cutedsl_v1_path()
    return on_v1_path
@property
def _is_cutedsl_v2_standard(self) -> bool:
"""CuteDSL v2 standard path (a2a=none or flashinfer, uses CuteDslMoEWrapper)."""
return self.enable_flashinfer_cutedsl_moe and not self._is_cutedsl_v1_deepep
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Allocate and register the NVFP4 MoE parameters on `layer`.
#
# Registers packed FP4 weights (w13 / w2), their FP8-E4M3 block scales,
# the per-tensor float32 weight scales (`*_weight_scale_2`) and the
# per-expert input scales. For non-TRTLLM backends it also creates the
# `*_blockscale_swizzled` parameters.
#
# Args:
#   layer: module that receives the parameters; `num_local_experts`,
#     `num_experts` and `moe_runner_config` are read from it.
#   num_experts: not read in this method (sizes come from `layer`).
#   hidden_size: hidden dimension used to size the weight tensors.
#   intermediate_size_per_partition: per-partition intermediate dimension.
#   params_dtype: recorded on the layer as `params_dtype`.
#   extra_weight_attrs: only `weight_loader` is read.
# Raises:
#   ValueError: checkpoint is neither NVFP4-serialized nor nvfp4_online,
#     or nvfp4_online is used without a FlashInfer TRTLLM MoE backend.
is_nvfp4_online = getattr(self.quant_config, "is_nvfp4_online", False)
if not self.quant_config.is_checkpoint_nvfp4_serialized and not is_nvfp4_online:
raise ValueError(
"NVFP4 quantization was selected, "
" dynamic quantization is not supported."
)
# `nvfp4_online` is not a serialized checkpoint format, but after the
# online loader converts each expert it uses the same packed NVFP4
# weights, block scales, and per-tensor scales as serialized ModelOpt
# NVFP4 checkpoints. Reuse this layout and swap only the weight loader.
if is_nvfp4_online:
if not self.enable_flashinfer_trtllm_moe:
raise ValueError(
"--quantization nvfp4_online supports only "
"--moe-runner-backend flashinfer_trtllm or "
"flashinfer_trtllm_routed."
)
# TODO(ch-wan): check if this is needed
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.params_dtype = params_dtype
layer.quant_config = self.quant_config
# Packed FP4 values are stored as uint8; block scales as FP8-E4M3.
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader")
# Online quantization wraps the provided loader.
if is_nvfp4_online:
weight_loader = self.get_online_weight_loader(layer, weight_loader)
# GEMM 1
# Gated activations store two shards in w13; non-gated ones store one.
num_shards = 2 if layer.moe_runner_config.is_gated else 1
w13_weight = ModelWeightParameter(
data=torch.empty(
layer.num_local_experts,
num_shards * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
dtype=weight_dtype,
),
input_dim=1,
output_dim=2,
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight", w13_weight)
# GEMM 2
w2_weight = ModelWeightParameter(
data=torch.empty(
layer.num_local_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=weight_dtype,
),
input_dim=1,
output_dim=2,
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight", w2_weight)
# Block scales for GEMM 1: one entry per `group_size` hidden elements.
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
layer.num_local_experts,
num_shards * intermediate_size_per_partition,
hidden_size // self.quant_config.group_size,
dtype=weight_scale_dtype,
),
input_dim=1,
output_dim=2,
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
# TRTLLM replaces blockscale_swizzled with an alias to weight_scale
# during process_weights_after_loading, so skip the expensive
# swizzle+allocate here to avoid GPU memory fragmentation
if self.enable_flashinfer_trtllm_moe:
layer.w13_blockscale_swizzled = None
else:
layer.w13_blockscale_swizzled = Parameter(
swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
# Block scales for GEMM 2: one entry per `group_size` intermediate elements.
w2_weight_scale = ModelWeightParameter(
data=torch.empty(
layer.num_local_experts,
hidden_size,
intermediate_size_per_partition // self.quant_config.group_size,
dtype=weight_scale_dtype,
),
input_dim=1,
output_dim=2,
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Same TRTLLM shortcut as for w13 above.
if self.enable_flashinfer_trtllm_moe:
layer.w2_blockscale_swizzled = None
else:
layer.w2_blockscale_swizzled = Parameter(
swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
# NOTE(review): extra_weight_attrs is this call's own **kwargs dict and is
# not read again below, so this update appears to have no effect here;
# confirm before relying on it.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
)
# Per-tensor weight scales: one value per shard for gated activations.
w13_weight_scale_shape = (
(layer.num_local_experts, 2)
if layer.moe_runner_config.is_gated
else (layer.num_local_experts,)
)
w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(w13_weight_scale_shape, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
w2_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(layer.num_local_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
if is_nvfp4_online and self.quant_config.is_checkpoint_fp8_serialized:
# FP8 checkpoints usually store expert scales as weight_scale_inv.
# Online NVFP4 consumes them in the loader and writes the generated
# NVFP4 scales into w*_weight_scale / w*_weight_scale_2 instead.
# The placeholders below are created with zero elements.
w13_source_weight_scale_inv = PerTensorScaleParameter(
data=torch.empty(0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter(
"w13_weight_scale_inv", w13_source_weight_scale_inv
)
w2_source_weight_scale_inv = PerTensorScaleParameter(
data=torch.empty(0, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight_scale_inv", w2_source_weight_scale_inv)
# NOTE(review): as with the BLOCK update above, this dict is not read
# again in this method; confirm whether the update is still needed.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# Input scales are sized by layer.num_experts rather than
# layer.num_local_experts, and are flagged with
# `_sglang_require_global_experts`.
w13_input_scale_shape = (layer.num_experts, num_shards)
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(w13_input_scale_shape, dtype=torch.float32),
weight_loader=weight_loader,
)
w13_input_scale._sglang_require_global_experts = True
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(
data=torch.empty(layer.num_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
w2_input_scale._sglang_require_global_experts = True
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Process FP4 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
# GEMM 1 scale processing
if layer.moe_runner_config.is_gated:
if layer.w13_weight_scale_2.dim() == 1:
# Some checkpoints store a shared scale for w1/w3.
w13_weight_scale_2 = layer.w13_weight_scale_2
else:
if layer.w13_weight_scale_2.shape[1] >= 2 and not torch.allclose(
layer.w13_weight_scale_2[:, 0],
layer.w13_weight_scale_2[:, 1],
):
logger.warning_once(
"w1_weight_scale_2 must match w3_weight_scale_2. "
"Accuracy may be affected."
)
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
else:
w13_weight_scale_2 = layer.w13_weight_scale_2[:]
moe_runner_backend = getattr(
self, "_moe_runner_backend", get_moe_runner_backend()
)
if moe_runner_backend.is_marlin():
copy_or_rebind_param(
layer,
"w13_weight_scale_2",
w13_weight_scale_2.contiguous(),
)
prepare_moe_nvfp4_layer_for_marlin(layer)
return
# Calculate input scales based on strategy
if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
elif self.enable_flashinfer_cutedsl_moe:
# CuteDSL standard path uses a single scalar input scale (all experts).
w13_input_scale = (
layer.w13_input_scale.max()
.to(torch.float32)
.repeat(layer.w13_input_scale.shape[0])
)
w2_input_scale = layer.w2_input_scale
def _slice_scale(w):
assert w.shape == (layer.num_experts,)
assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
return w[
layer.moe_ep_rank
* layer.num_local_experts : (layer.moe_ep_rank + 1)
* layer.num_local_experts
]
w13_input_scale = _slice_scale(w13_input_scale)
w2_input_scale = _slice_scale(w2_input_scale)
if MOE_NVFP4_DISPATCH:
assert torch.all(w13_input_scale == w13_input_scale[0])
w13_input_scale = w13_input_scale[0]
else:
w13_input_scale = layer.w13_input_scale.max(dim=-1).values.to(torch.float32)
w2_input_scale = layer.w2_input_scale
if self.quant_config.use_per_token_activation:
# FlashInfer computes activation scales dynamically per token, so
# the static checkpoint activation scale is intentionally neutral.
w13_input_scale = torch.ones_like(w13_input_scale, dtype=torch.float32)
w2_input_scale = torch.ones_like(w2_input_scale, dtype=torch.float32)
# Create shared parameters
copy_or_rebind_param(
layer,
"g1_alphas",
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
)
copy_or_rebind_param(
layer,
"g2_alphas",
(w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
)
copy_or_rebind_param(
layer,
"w13_input_scale_quant",
(1 / w13_input_scale).to(torch.float32),
)
copy_or_rebind_param(
layer,
"w2_input_scale_quant",
(1 / w2_input_scale).to(torch.float32),
)
# TODO: for flashinfer always do MOE_NVFP4_DISPATCH
layer.dispatcher.set_quant_config(
{
"input_global_scale": (
layer.w13_input_scale_quant
if MOE_NVFP4_DISPATCH
or should_use_flashinfer_cutlass_moe_fp4_allgather()
else None
)
}
)
block_size = 16
# Validate weight scales
assert_dim = 2 if layer.moe_runner_config.is_gated else 1
for name, weight_scale in [
("w13", layer.w13_weight_scale),
("w2", layer.w2_weight_scale),
]:
# For NVFP4 TRTLLM we require one scale per 16 inputs (last dim == expected_blocks[name]).
if get_moe_runner_backend().is_flashinfer_trtllm():
expected_blocks = {
"w13": layer.w13_weight.shape[2] * 2 // block_size,
"w2": layer.w2_weight.shape[2] * 2 // block_size,
}
assert (
weight_scale.shape[-1] == expected_blocks[name]
), f"Expected {name}_weight_scale.dim(2) == {expected_blocks[name]}, got {weight_scale.shape[-1]}"
else:
if weight_scale.shape[assert_dim] % 4 != 0:
logger.warning(
"NVFP4 %s_weight_scale K' not multiple of 4: shape=%s, group_size=%s",
name,
tuple(weight_scale.shape),
getattr(self.quant_config, "group_size", None),
)
assert (
weight_scale.dtype == torch.float8_e4m3fn
), f"{name} Weight Blockscale must be represented as FP8-E4M3"
# Weight processing based on strategy
if (
self.enable_flashinfer_trtllm_moe
and reorder_rows_for_gated_act_gemm is not None
and shuffle_matrix_sf_a is not None
):
from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
align_fp4_moe_weights_for_flashinfer_trtllm,
)
# FlashInfer TRTLLM processing - handles both w13 and w2
align_fp4_moe_weights_for_flashinfer_trtllm(layer)
# TRTLLM doesn't read *_blockscale_swizzled; alias to free the
# placeholders from create_weights.
layer.w13_blockscale_swizzled = layer.w13_weight_scale
layer.w2_blockscale_swizzled = layer.w2_weight_scale
else:
# CUTLASS processing - handle w13 and w2 separately
if self._is_cutedsl_v2_standard and layer.moe_runner_config.is_gated:
# CuteDSL v2 only: interleave the two logical W13 halves in
# 64-row chunks for the fused SwiGLU GEMM1 layout expected by
# CuteDslMoEWrapper. The v1 (deepep) path uses
# grouped_gemm_nt_masked which expects plain contiguous halves.
from sglang.srt.layers.moe.moe_runner.flashinfer_cutedsl import (
interleave_w13_halves,
)
layer.w13_weight = Parameter(
interleave_w13_halves(
layer.w13_weight.view(torch.uint8), group_size=64, dim=1
).contiguous(),
requires_grad=False,
)
layer.w13_weight_scale = Parameter(
interleave_w13_halves(
layer.w13_weight_scale, group_size=64, dim=1
).contiguous(),
requires_grad=False,
)
# Process w13 weights
w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale)
alias_or_bind_derived_param(
layer,
"w13_weight_scale",
"w13_blockscale_swizzled",
w13_blockscale_swizzled,
)
w13_weight = layer.w13_weight
intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
if intermediate_size_pad:
# padding gated activations will require to split w1 and w3
# and pad them individually
assert not layer.moe_runner_config.is_gated, (
"The intermediate size required padding, "
"but padding is also implemented for gated activations"
)
copy_or_rebind_param(
layer,
"w13_weight",
torch.nn.functional.pad(
w13_weight, (0, 0, 0, intermediate_size_pad)
),
)
copy_or_rebind_param(
layer,
"w2_weight",
torch.nn.functional.pad(
layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
),
)
copy_or_rebind_param(
layer,
"w2_weight_scale",
torch.nn.functional.pad(
layer.w2_weight_scale, (0, intermediate_size_pad // 16)
),
)
# Process w2 weights
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
alias_or_bind_derived_param(
layer,
"w2_weight_scale",
"w2_blockscale_swizzled",
w2_blockscale_swizzled,
)
if self._is_cutedsl_v2_standard:
# CuteDSL v2 only: convert blockscales to MMA layout for
# CuteDslMoEWrapper. The v1 (deepep) path uses the
# swizzled blockscales directly via flashinfer_cutedsl_moe_masked.
from flashinfer.cute_dsl.utils import convert_sf_to_mma_layout
from sglang.srt.layers.moe.moe_runner.flashinfer_cutedsl import (
_FP4_SF_VEC_SIZE,
)
sf_vec_size = _FP4_SF_VEC_SIZE
num_local_experts = layer.w13_weight.shape[0]
w13_m = layer.w13_weight.shape[1]
w13_k = layer.w13_weight.shape[2] * 2
w2_m = layer.w2_weight.shape[1]
w2_k = layer.w2_weight.shape[2] * 2
layer.w13_blockscale_mma = Parameter(
convert_sf_to_mma_layout(
layer.w13_blockscale_swizzled.contiguous()
.view(torch.uint8)
.reshape(-1),
m=w13_m,
k=w13_k,
num_groups=num_local_experts,
sf_vec_size=sf_vec_size,
),
requires_grad=False,
)
layer.w2_blockscale_mma = Parameter(
convert_sf_to_mma_layout(
layer.w2_blockscale_swizzled.contiguous()
.view(torch.uint8)
.reshape(-1),
m=w2_m,
k=w2_k,
num_groups=num_local_experts,
sf_vec_size=sf_vec_size,
),
requires_grad=False,
)
# Both flashinfer cutlass and regular cutlass use same processing for w2
# Set up CUTLASS MoE parameters (reuse to keep CUDA graph stable)
device = layer.w13_weight.device
inter_size = layer.w2_weight.shape[2] * 2
hidden_size = layer.w13_weight.shape[2] * 2
existing_params = getattr(layer, "cutlass_moe_params", None)
if (
existing_params is None
or existing_params.cutlass_moe_type != CutlassMoEType.BlockscaledFP4
or existing_params.num_experts != layer.num_experts
or existing_params.intermediate_size_per_partition != inter_size
or existing_params.hidden_size != hidden_size
or existing_params.device != device
):
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device,
num_experts=layer.num_experts, # global num experts
intermediate_size_per_partition=inter_size, # n
hidden_size=hidden_size,
) # k
@property
def load_up_proj_weight_first(self) -> bool:
    """Whether W13 should be loaded in [Up, Gate] order.

    The FlashInfer CUTLASS and CuteDSL v2 kernels take [Up, Gate]; the
    CuteDSL v1 (deepep) path keeps [Gate, Up], so it must not be flipped.
    Non-gated layers never request the flipped order.
    """
    gated = self.moe_runner_config.is_gated
    if not gated:
        # Mirror short-circuit `and`: hand back the falsy flag itself.
        return gated
    return self.enable_flashinfer_cutlass_moe or self._is_cutedsl_v2_standard
def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Record the runner config and build the MoE runner for the chosen backend.

    Stores ``moe_runner_config`` and the resolved backend on ``self``. A
    ``MoeRunner`` is created for every backend except FlashInfer CUTLASS,
    for which ``self.runner`` is left untouched.

    Args:
        layer: The MoE layer being set up (not read here).
        moe_runner_config: Runner configuration to keep on this method object.
    """
    self.moe_runner_config = moe_runner_config

    backend = get_moe_runner_backend()
    if backend.is_auto():
        # CUDA devices reporting a capability in [(8, 0), (10, 0)) get Marlin.
        # Everything else falls back to TRTLLM, currently the most performant
        # and tested FP4 MoE backend.
        use_marlin = is_cuda() and (8, 0) <= get_device_capability() < (10, 0)
        backend = (
            MoeRunnerBackend.MARLIN
            if use_marlin
            else MoeRunnerBackend.FLASHINFER_TRTLLM
        )
    self._moe_runner_backend = backend

    if backend.is_flashinfer_cutedsl():
        import sglang.srt.layers.moe.moe_runner.flashinfer_cutedsl  # noqa: F401 – triggers @register_fused_func

    if backend.is_flashinfer_cutlass():
        # No MoeRunner is built for the FlashInfer CUTLASS backend.
        return
    self.runner = MoeRunner(backend, moe_runner_config)
def apply(
    self,
    layer: FusedMoE,
    dispatch_output: StandardDispatchOutput,
) -> CombineInput:
    """Run the NVFP4 MoE forward pass on the backend selected for this method.

    Branches are tried in this order: Marlin, FlashInfer TRTLLM, FlashInfer
    CuteDSL (v1 DeepEP or v2 wrapper), FlashInfer CUTLASS, and finally the
    ``cutlass_moe_fp4`` kernel.

    Args:
        layer: MoE layer holding the processed weights and scale tensors.
        dispatch_output: Output of the token dispatcher. It may be a
            DeepEPLLDispatchOutput, which has no ``topk_output`` attribute
            (topk_ids/topk_weights live directly on the dispatch tuple), so
            attributes are only read inside the branches that consume them.

    Returns:
        The result of ``self.runner.run`` for runner-backed branches, or a
        ``StandardCombineInput`` wrapping the kernel output otherwise.
    """
    from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

    moe_runner_config = self.moe_runner_config
    # NOTE: the name `activation` is part of the assertion message below
    # (f-string `=` specifier), so it must not be renamed.
    activation = moe_runner_config.activation
    backend = getattr(self, "_moe_runner_backend", get_moe_runner_backend())
    assert (
        activation in _SUPPORTED_ACT_STRS
    ), f"{activation=} not in supported {_SUPPORTED_ACT_STRS}"

    # ---- Marlin path ----
    if backend.is_marlin():
        from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo

        # The expert map is optional: it is only present when the layer has a
        # dispatcher exposing `local_expert_mapping`.
        expert_map = getattr(
            getattr(layer, "dispatcher", None), "local_expert_mapping", None
        )
        global_num_experts = (
            -1 if expert_map is None else moe_runner_config.num_experts
        )
        marlin_info = MarlinMoeQuantInfo(
            w13_qweight=layer.w13_weight,
            w2_qweight=layer.w2_weight,
            w13_scales=layer.w13_weight_scale,
            w2_scales=layer.w2_weight_scale,
            w13_g_idx_sort_indices=None,
            w2_g_idx_sort_indices=None,
            weight_bits=4,
            w13_global_scale=layer.w13_weight_scale_2,
            w2_global_scale=layer.w2_weight_scale_2,
            expert_map=expert_map,
            global_num_experts=global_num_experts,
        )
        return self.runner.run(dispatch_output, marlin_info)

    # ---- FlashInfer TRTLLM FP4 path ----
    if self.enable_flashinfer_trtllm_moe and hasattr(layer, "g1_scale_c"):
        from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
            FlashInferTrtllmFp4MoeQuantInfo,
        )
        from sglang.srt.layers.moe.utils import RoutingMethodType

        # The layer may override the routing method; otherwise use Default.
        routing_method_type = getattr(
            layer, "routing_method_type", RoutingMethodType.Default
        )
        trtllm_info = FlashInferTrtllmFp4MoeQuantInfo(
            w13_weight=layer.w13_weight.data,
            w2_weight=layer.w2_weight.data,
            w13_weight_scale=layer.w13_weight_scale.data,
            w2_weight_scale=layer.w2_weight_scale.data,
            g1_scale_c=layer.g1_scale_c.data,
            g1_alphas=layer.g1_alphas.data,
            g2_alphas=layer.g2_alphas.data,
            w13_input_scale_quant=layer.w13_input_scale_quant,
            global_num_experts=layer.num_experts,
            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
            local_num_experts=layer.num_local_experts,
            intermediate_size_per_partition=layer.intermediate_size_per_partition,
            routing_method_type=routing_method_type,
            use_per_token_activation=self.quant_config.use_per_token_activation,
        )
        return self.runner.run(dispatch_output, trtllm_info)

    # ---- FlashInfer CuteDSL path (v1 DeepEP or v2 wrapper) ----
    if self.enable_flashinfer_cutedsl_moe:
        from sglang.srt.layers.moe.moe_runner.flashinfer_cutedsl import (
            CuteDslFp4MoeQuantInfo,
            ensure_cutedsl_wrapper,
        )

        if self._is_cutedsl_v1_deepep:
            # v1: DeepEP low-latency + flashinfer_cutedsl_moe_masked.
            # Weights stay [Gate, Up] (non-interleaved) and the swizzled
            # blockscales are passed through unchanged.
            cutedsl_info = CuteDslFp4MoeQuantInfo(
                w13_weight=layer.w13_weight,
                w2_weight=layer.w2_weight,
                w13_weight_sf=layer.w13_blockscale_swizzled,
                w2_weight_sf=layer.w2_blockscale_swizzled,
                w1_alpha=layer.g1_alphas,
                w2_alpha=layer.g2_alphas,
                a1_scale=layer.w13_input_scale_quant,
                a2_scale=layer.w2_input_scale_quant,
                use_nvfp4_dispatch=MOE_NVFP4_DISPATCH,
                down_gemm_overlap_args=getattr(
                    self.runner, "down_gemm_overlap_args", None
                ),
            )
        else:
            # v2 standard (a2a=none/flashinfer): CuteDslMoEWrapper with
            # [Up, Gate] interleaved weights; MMA-layout blockscales are
            # used when the layer has them, else the swizzled ones.
            ensure_cutedsl_wrapper(layer)
            w1_alpha, fc2_input_scale, w2_alpha = layer._cutedsl_scales
            cutedsl_info = CuteDslFp4MoeQuantInfo(
                w13_weight=layer.w13_weight,
                w2_weight=layer.w2_weight,
                w13_weight_sf=getattr(
                    layer, "w13_blockscale_mma", layer.w13_blockscale_swizzled
                ),
                w2_weight_sf=getattr(
                    layer, "w2_blockscale_mma", layer.w2_blockscale_swizzled
                ),
                w1_alpha=w1_alpha,
                w2_alpha=w2_alpha,
                a1_scale=layer._cutedsl_input_scale,
                a2_scale=fc2_input_scale,
                wrapper=layer._cutedsl_wrapper,
            )
        return self.runner.run(dispatch_output, cutedsl_info)

    # ---- FlashInfer CUTLASS path (direct kernel call, no MoeRunner) ----
    if self.enable_flashinfer_cutlass_moe:
        from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
            get_activation_type,
        )
        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

        assert (
            not moe_runner_config.apply_router_weight_on_input
        ), "apply_router_weight_on_input is not supported for Flashinfer"

        # Map the activation string (honoring the gated flag) to a FlashInfer
        # ActivationType and make sure the CUTLASS FP4 kernel accepts it.
        fi_activation = ActivationType(
            get_activation_type(activation, is_gated=moe_runner_config.is_gated)
        )
        supported_activations = {
            ActivationType.Swiglu,
            ActivationType.Geglu,
            ActivationType.Relu2,
            ActivationType.Identity,
        }
        assert fi_activation in supported_activations, (
            f"Activation {activation!r} (is_gated={moe_runner_config.is_gated}) "
            f"maps to {fi_activation.name}, which is not supported by the "
            "flashinfer cutlass fp4 moe kernel."
        )

        # The kernel takes BF16/Half/nvfp4 activations together with the fp4
        # quantized weights loaded from the checkpoint.
        x = dispatch_output.hidden_states
        x_sf = dispatch_output.hidden_states_scale
        topk_output = dispatch_output.topk_output
        topk_weights = topk_output.topk_weights
        topk_ids = topk_output.topk_ids
        output_dtype = torch.bfloat16

        if DispatchOutputChecker.format_is_flashinfer(dispatch_output):
            # The dispatcher already supplies the output buffer.
            symm_output = dispatch_output.moe_output
        else:
            # With a scale tensor present x is FP4-packed (half width), so the
            # output width is doubled; without one x is unpacked.
            # NOTE(review): the doubling is additionally gated on
            # layer.moe_runner_config.is_gated — confirm that is intended.
            is_packed_gated = (
                x_sf is not None and layer.moe_runner_config.is_gated
            )
            output_col = x.shape[1] * 2 if is_packed_gated else x.shape[1]
            with use_symmetric_memory(
                get_tp_group(), disabled=not is_allocation_symmetric()
            ):
                symm_output = torch.empty(
                    x.shape[0],
                    output_col,
                    dtype=output_dtype,
                    device=x.device,
                )

        # Parameterized SwiGLU (GPT-OSS-style clamped swiglu, e.g.
        # alpha=1.702, limit=7.0) has to be forwarded explicitly; otherwise
        # the kernel computes vanilla SwiGLU and such models generate garbage.
        # swiglu_beta=1.0 (the +1 shift on the linear branch) follows the
        # mxfp4 path.
        swiglu_kwargs = {}
        gemm1_alpha = moe_runner_config.gemm1_alpha
        gemm1_limit = moe_runner_config.gemm1_clamp_limit
        if gemm1_alpha is not None or gemm1_limit is not None:
            num_local_experts = layer.w13_weight.shape[0]

            def per_expert(value):
                # One float32 entry per local expert, all set to `value`.
                return torch.full(
                    (num_local_experts,),
                    value,
                    dtype=torch.float32,
                    device=x.device,
                )

            swiglu_kwargs["swiglu_alpha"] = per_expert(
                1.0 if gemm1_alpha is None else gemm1_alpha
            )
            swiglu_kwargs["swiglu_beta"] = per_expert(1.0)
            if gemm1_limit is not None:
                swiglu_kwargs["swiglu_limit"] = per_expert(gemm1_limit)

        # Order matters: fc1 (input scale, blockscale, alpha), then fc2.
        quant_scales = [
            layer.w13_input_scale_quant,
            layer.w13_blockscale_swizzled.view(torch.int32),
            layer.g1_alphas,
            layer.w2_input_scale_quant,
            layer.w2_blockscale_swizzled.view(torch.int32),
            layer.g2_alphas,
        ]
        # swizzled_input_sf is intentionally not passed; unused on this path.
        kernel_outputs = flashinfer_cutlass_fused_moe(
            output=symm_output,
            input=x,
            token_selected_experts=topk_ids.to(torch.int),
            token_final_scales=topk_weights,
            fc1_expert_weights=layer.w13_weight.view(torch.long),
            fc2_expert_weights=layer.w2_weight.view(torch.long),
            output_dtype=output_dtype,
            input_sf=x_sf,
            quant_scales=quant_scales,
            ep_size=layer.moe_ep_size,
            ep_rank=layer.moe_ep_rank,
            tp_size=layer.moe_tp_size,
            tp_rank=layer.moe_tp_rank,
            tune_max_num_tokens=next_power_of_2(x.shape[0]),
            activation_type=fi_activation,
            enable_alltoall=get_moe_a2a_backend().is_flashinfer(),
            **swiglu_kwargs,
        )
        return StandardCombineInput(hidden_states=kernel_outputs[0])

    # ---- Fallback: cutlass_moe_fp4 kernel ----
    from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

    x = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output
    topk_weights = topk_output.topk_weights
    topk_ids = topk_output.topk_ids

    fp4_output = cutlass_moe_fp4(
        a=x,
        a1_gscale=layer.w13_input_scale_quant,
        w1_fp4=layer.w13_weight,
        w1_blockscale=layer.w13_blockscale_swizzled,
        w1_alphas=layer.g1_alphas,
        a2_gscale=layer.w2_input_scale_quant,
        w2_fp4=layer.w2_weight,
        w2_blockscale=layer.w2_blockscale_swizzled,
        w2_alphas=layer.g2_alphas,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        params=layer.cutlass_moe_params,
        apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
        no_combine=moe_runner_config.no_combine,
    )
    # Cast back to the activation dtype. Scaling by routed_scaling_factor is
    # fused into select_experts, so none is applied here.
    return StandardCombineInput(hidden_states=fp4_output.to(x.dtype))