| # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py | |
| from __future__ import annotations | |
| import logging | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.nn import Module | |
| from torch.nn.parameter import Parameter | |
# The Marlin FP8 helpers come from vLLM, which is an optional dependency.
# When the import fails, both helper names are bound to a stub that raises a
# descriptive ImportError on first use, so importing this module still works.
try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )

    MARLIN_FP8_AVAILABLE = True
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def dummy_func(*args, **kwargs):
        """Stand-in for the vLLM Marlin FP8 helpers; always raises ImportError."""
        raise ImportError(
            "marlin FP8 requires some operators from vllm. Please install vllm."
        )

    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
| from sglang.srt.distributed import get_tensor_model_parallel_world_size | |
| from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading | |
| from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig | |
| from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo | |
| from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo | |
| from sglang.srt.layers.parameter import ( | |
| BlockQuantScaleParameter, | |
| ModelWeightParameter, | |
| PerTensorScaleParameter, | |
| ) | |
| from sglang.srt.layers.quantization.base_config import ( | |
| FusedMoEMethodBase, | |
| LinearMethodBase, | |
| QuantizationConfig, | |
| QuantizeMethodBase, | |
| ) | |
| from sglang.srt.layers.quantization.fp8_kernel import ( | |
| fp8_dtype, | |
| is_fp8_fnuz, | |
| per_token_group_quant_fp8, | |
| scaled_fp8_quant, | |
| ) | |
| from sglang.srt.layers.quantization.fp8_utils import ( | |
| apply_fp8_linear, | |
| can_auto_enable_marlin_fp8, | |
| cutlass_fp8_supported, | |
| dispatch_w8a8_block_fp8_linear, | |
| input_to_float8, | |
| normalize_e4m3fn_to_e4m3fnuz, | |
| ) | |
| from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod | |
| from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod | |
| from sglang.srt.layers.quantization.utils import ( | |
| all_close_1d, | |
| convert_to_channelwise, | |
| is_layer_skipped, | |
| per_tensor_dequantize, | |
| requantize_with_max_scale, | |
| ) | |
| from sglang.srt.utils import ( | |
| cpu_has_amx_support, | |
| get_bool_env_var, | |
| is_cpu, | |
| is_cuda, | |
| is_hip, | |
| is_npu, | |
| is_sm90_supported, | |
| is_sm100_supported, | |
| log_info_on_rank0, | |
| next_power_of_2, | |
| print_warning_once, | |
| set_weight_attrs, | |
| use_intel_amx_backend, | |
| ) | |
| if TYPE_CHECKING: | |
| from sglang.srt.layers.moe.token_dispatcher import ( | |
| CombineInput, | |
| DispatchOutput, | |
| StandardDispatchOutput, | |
| ) | |
| from sglang.srt.layers.moe.topk import TopKOutput | |
| from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config | |
# Platform / feature flags. Each helper is called exactly once, at import
# time, and the result is reused by every method in this module.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_fp8_fnuz = is_fp8_fnuz()

# Opt-in switches read from the environment.
_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
# aiter is only honoured on HIP, even if the variable is set elsewhere.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

# The aiter package is imported only on HIP and only when one of the two
# switches above is enabled; the names below are undefined otherwise.
if _is_hip and (_use_aiter or _use_hip_int4):
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight

# Activation schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)
class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: True when the checkpoint already stores
            FP8 weights (as opposed to FP16/BF16 weights that are quantized
            after loading).
        activation_scheme: One of ``ACTIVATION_SCHEMES`` ("static" or
            "dynamic").
        ignored_layers: Layer prefixes that must stay unquantized. ``None``
            is stored as an empty list.
        weight_block_size: Optional two-element block size enabling
            block-wise weight quantization.

    Raises:
        ValueError: If the activation scheme is unknown, or if block-wise
            quantization is requested together with a non-serialized
            checkpoint, a block size that does not have exactly two entries,
            or a non-dynamic activation scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            log_info_on_rank0(logger, "Detected fp8 checkpoint.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []
        if weight_block_size is not None:
            # Block-wise quantization carries three extra constraints.
            if not is_checkpoint_fp8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports fp8-serialized checkpoint for now."
                )
            if len(weight_block_size) != 2:
                raise ValueError(
                    f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
                )
            if activation_scheme != "dynamic":
                raise ValueError(
                    f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
                )
        self.weight_block_size = weight_block_size

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method can be combined with."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum required device capability (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config file names to look for (none for FP8)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
        """Build an ``Fp8Config`` from a quantization config dictionary.

        ``quant_method`` and ``activation_scheme`` are looked up with
        ``get_from_keys``; ``ignored_layers`` and ``weight_block_size`` are
        optional and default to ``None``. The checkpoint is treated as FP8
        serialized when the ``quant_method`` value contains ``"fp8"``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Select the quantization method object for ``layer``.

        Returns ``UnquantizedLinearMethod`` for linear layers whose prefix is
        in ``ignored_layers``, ``Fp8LinearMethod`` for other linear layers,
        ``Fp8MoEMethod`` for fused MoE layers, and ``None`` for anything else.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for FP8)."""
        return []
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = False
        if _is_cuda and MARLIN_FP8_AVAILABLE:
            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
            auto_enable = can_auto_enable_marlin_fp8()
            self.use_marlin = force_marlin or auto_enable

        self.block_quant = self.quant_config.weight_block_size is not None

        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()

    def _uses_static_activation(self) -> bool:
        """Return True when the config asks for static activation scales.

        The config may expose the scheme as ``activation_scheme`` or as
        ``linear_activation_scheme``; a missing attribute counts as
        "not static".
        NOTE(review): ``linear_activation_scheme`` is presumably the
        W4AFp8Config spelling - confirm against that class.
        """
        return any(
            getattr(self.quant_config, attr, None) == "static"
            for attr in ("activation_scheme", "linear_activation_scheme")
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight and scale parameters on ``layer``.

        Always registers ``weight``. For FP8-serialized checkpoints it also
        registers ``weight_scale_inv`` (block quantization) or
        ``weight_scale`` (one entry per output partition), plus
        ``input_scale`` (a parameter for static activation, ``None``
        otherwise). Several size attributes are stored on ``layer`` for later
        use.

        Raises:
            ValueError: With block quantization, when a partition size is not
                divisible by the corresponding block dimension.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        tp_size = get_tensor_model_parallel_world_size()
        if self.block_quant:
            block_n, block_k = (
                self.quant_config.weight_block_size[0],
                self.quant_config.weight_block_size[1],
            )
            # Required by row parallel
            if tp_size > 1 and input_size // input_size_per_partition == tp_size:
                if input_size_per_partition % block_k != 0:
                    raise ValueError(
                        f"Weight input_size_per_partition = "
                        f"{input_size_per_partition} is not divisible by "
                        f"weight quantization block_k = {block_k}."
                    )
            # Required by column parallel or enabling merged weights
            if (
                tp_size > 1 and output_size // output_size_per_partition == tp_size
            ) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}."
                        )

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=weight_dtype
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if self.block_quant:
                if hasattr(self.quant_config, "activation_scheme"):
                    assert self.quant_config.activation_scheme == "dynamic"
                elif hasattr(self.quant_config, "linear_activation_scheme"):
                    assert self.quant_config.linear_activation_scheme == "dynamic"
                # One scale per (block_n x block_k) tile, rounded up.
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                # Fill with the most negative float32 as a placeholder value.
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale_inv", scale)
            else:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self._uses_static_activation():
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader,
                )

                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize weights and scales once the checkpoint has been loaded.

        Block quantization: optionally normalizes to e4m3fnuz, or hands the
        weight to the AMX packing routine on CPU (and returns early).
        Otherwise: quantizes FP16/BF16 weights to FP8, or converts the
        per-partition scales of an FP8 checkpoint to per-channel scales or a
        single requantized scale. When Marlin is enabled the layer is then
        prepared for the Marlin kernel and ``input_scale`` is removed.
        """
        if self.block_quant:
            # If ROCm, normalize the weights and scales to e4m3fnuz
            if _is_fp8_fnuz:
                # activation_scheme: dynamic
                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale_inv,
                    input_scale=None,
                )
                layer.input_scale = None
            elif _is_cpu:
                assert (
                    _is_cpu_amx_available
                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
                _amx_process_weight_after_loading(layer, ["weight"])
                layer.weight_scale_inv = torch.nn.Parameter(
                    layer.weight_scale_inv.data, requires_grad=False
                )
                return
            else:
                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
            # Write the (possibly normalized) tensors back in place.
            layer.weight.data = weight.data
            layer.weight_scale_inv.data = weight_scale.data
        else:
            layer.weight = Parameter(layer.weight.data, requires_grad=False)

            # If checkpoint not serialized fp8, quantize the weights.
            if not self.quant_config.is_checkpoint_fp8_serialized:
                if self.cutlass_fp8_supported or self.use_marlin:
                    # apply per-channel quantization default as
                    # cutlass sgl-kernel and marlin only support per-channel scale
                    qweight, weight_scale = per_token_group_quant_fp8(
                        layer.weight, layer.weight.shape[-1]
                    )
                    weight_scale = weight_scale.t().contiguous()
                else:
                    # per-tensor quantization
                    qweight, weight_scale = input_to_float8(layer.weight)

                # Update the layer with the new values.
                layer.weight = Parameter(qweight.t(), requires_grad=False)
                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
                layer.input_scale = None

            # If checkpoint is fp8, handle that there are N scales for N
            # shards in a fused module
            else:
                layer.weight_scale = Parameter(
                    layer.weight_scale.data, requires_grad=False
                )
                if self._uses_static_activation():
                    layer.input_scale = Parameter(
                        layer.input_scale.data, requires_grad=False
                    )

                # cutlass sgl-kernel and marlin only support per-channel scale
                if self.cutlass_fp8_supported or self.use_marlin:
                    weight = layer.weight
                    weight_scale = convert_to_channelwise(
                        layer.weight_scale, layer.logical_widths
                    )
                else:
                    # Dequant -> Quant with max scale so we can run per tensor.
                    weight = layer.weight
                    weight_scale = layer.weight_scale
                    # If ROCm, normalize the weights and scales to e4m3fnuz
                    if _is_fp8_fnuz:
                        weight, weight_scale, input_scale = (
                            normalize_e4m3fn_to_e4m3fnuz(
                                weight=weight,
                                weight_scale=weight_scale,
                                input_scale=layer.input_scale,
                            )
                        )
                        if input_scale is not None:
                            layer.input_scale = Parameter(
                                input_scale, requires_grad=False
                            )

                    weight_scale, weight = requantize_with_max_scale(
                        weight=weight,
                        weight_scale=weight_scale,
                        logical_widths=layer.logical_widths,
                    )

                # Update layer with new values.
                layer.weight = Parameter(weight.t(), requires_grad=False)
                layer.weight_scale = Parameter(weight_scale, requires_grad=False)
                if self._uses_static_activation():
                    # Collapse the per-partition input scales to their maximum.
                    layer.input_scale = Parameter(
                        layer.input_scale.max(), requires_grad=False
                    )

        if self.use_marlin:
            if self.block_quant:
                layer.weight_block_size = self.quant_config.weight_block_size
            prepare_fp8_layer_for_marlin(layer, not self.block_quant)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized linear layer on ``x``.

        Dispatch order: Marlin kernel if enabled; otherwise, for block
        quantization, the Intel AMX CPU kernel or the dispatched block-wise
        FP8 linear; otherwise ``apply_fp8_linear``.
        """
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        if self.block_quant:
            if use_intel_amx_backend(layer):
                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
                    x,
                    layer.weight,
                    layer.weight_scale_inv,
                    self.quant_config.weight_block_size,
                    bias,
                    x.dtype,
                    True,  # is_vnni
                )

            return self.w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=self.quant_config.weight_block_size,
                weight_scale=layer.weight_scale_inv,
                input_scale=None,
                bias=bias,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False,
        )
def get_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Pick the per-CTA token tile size for an MoE launch.

    The estimate assumes tokens are spread evenly over the experts, rounds
    that count with ``next_power_of_2`` and clamps the result to the
    8..64 range supported by the kernel.
    """
    lower_bound, upper_bound = 8, 64

    # Tokens routed to each expert under a perfectly balanced distribution.
    even_share = (num_tokens * top_k) // num_experts
    padded_share = next_power_of_2(even_share)

    if padded_share < lower_bound:
        return lower_bound
    if padded_share > upper_bound:
        return upper_bound
    return padded_share
| class Fp8MoEMethod(FusedMoEMethodBase): | |
| """MoE method for FP8. | |
| Supports loading FP8 checkpoints with static weight scale and | |
| dynamic/static activation scale. | |
| Also supports loading quantized FP16/BF16 model checkpoints with dynamic | |
| activation scaling. The weight scaling factor will be initialized after | |
| the model weights are loaded. | |
| Args: | |
| quant_config: The quantization config. | |
| """ | |
def __init__(self, quant_config: Fp8Config):
    """Store the config and derive the kernel-selection flags.

    ``block_quant`` is set when the config carries a weight block size.
    The cutlass fused-experts path is enabled only when the
    ``SGLANG_CUTLASS_MOE`` environment switch is on, cutlass FP8 is
    supported, block quantization is active and the device is reported as
    SM90 or SM100.
    """
    self.quant_config = quant_config
    self.block_quant = quant_config.weight_block_size is not None
    self.cutlass_fp8_supported = cutlass_fp8_supported()

    cutlass_moe_opt_in = get_bool_env_var("SGLANG_CUTLASS_MOE")
    self.use_cutlass_fused_experts_fp8 = (
        cutlass_moe_opt_in
        and self.cutlass_fp8_supported
        and self.block_quant
        and (is_sm100_supported() or is_sm90_supported())
    )
def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the expert weights, weight scales and input scales on ``layer``.

    Args:
        layer: Module that receives the parameters.
        num_experts: Leading dimension of every per-expert tensor.
        hidden_size: Input width of w13 and output width of w2.
        intermediate_size_per_partition: Half the output width of the merged
            w13 weight and the input width of w2.
        params_dtype: Weight dtype; overridden for FP8-serialized checkpoints.
        **extra_weight_attrs: Attributes attached to the created parameters
            via ``set_weight_attrs``; a ``quant_method`` entry is added here.

    Raises:
        ValueError: If block-quantization divisibility constraints are
            violated, or if a static activation scheme is combined with a
            checkpoint that is not FP8 serialized.
    """
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

    if self.quant_config.is_checkpoint_fp8_serialized:
        # Serialized checkpoints dictate the storage dtype: packed uint32 for
        # the HIP INT4 path, float8_e4m3fn otherwise.
        params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn
    tp_size = get_tensor_model_parallel_world_size()
    if self.block_quant:
        block_n, block_k = (
            self.quant_config.weight_block_size[0],
            self.quant_config.weight_block_size[1],
        )
        # NOTE(HandH1998): To ensure proper alignment of the block-wise
        # quantization scales, the output_size of the weights for both the
        # gate and up layers must be divisible by block_n.
        # Required by column parallel or enabling merged weights
        if intermediate_size_per_partition % block_n != 0:
            raise ValueError(
                f"The output_size of gate's and up's weight = "
                f"{intermediate_size_per_partition} is not divisible by "
                f"weight quantization block_n = {block_n}."
            )
        if tp_size > 1:
            # Required by row parallel
            if intermediate_size_per_partition % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

    # WEIGHTS
    if _is_hip and _use_hip_int4:
        # INT4 MoE weight - INT32 packed
        # The last dimension is divided by 8 (eight 4-bit values per word).
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // 8,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 8,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
    else:
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )

    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    if self.block_quant:
        # One scale per (block_n x block_k) tile, rounded up; w13 holds two
        # stacked shards, hence the factor 2 on its block-row count.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                (hidden_size + block_n - 1) // block_n,
                (intermediate_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
        assert self.quant_config.activation_scheme == "dynamic"
        if self.use_cutlass_fused_experts_fp8:
            # Stride, workspace, pointer, offset and problem-size buffers are
            # stored on the method object itself (not on the layer).
            # NOTE(review): presumably consumed by the cutlass fused-experts
            # kernel at apply time - confirm against the caller.
            self.ab_strides1 = torch.full(
                (num_experts,),
                hidden_size,
                device=w13_weight.device,
                dtype=torch.int64,
            )
            self.c_strides1 = torch.full(
                (num_experts,),
                2 * intermediate_size_per_partition,
                device=w13_weight.device,
                dtype=torch.int64,
            )
            self.ab_strides2 = torch.full(
                (num_experts,),
                intermediate_size_per_partition,
                device=w2_weight.device,
                dtype=torch.int64,
            )
            self.c_strides2 = torch.full(
                (num_experts,),
                hidden_size,
                device=w2_weight.device,
                dtype=torch.int64,
            )
            self.workspace = torch.empty(
                90000, device=w13_weight.device, dtype=torch.uint8
            )
            self.a_ptr = torch.empty(
                num_experts, device=w13_weight.device, dtype=torch.int64
            )
            self.b_ptr = torch.empty(
                num_experts, device=w13_weight.device, dtype=torch.int64
            )
            self.out_ptr = torch.empty(
                num_experts, device=w13_weight.device, dtype=torch.int64
            )
            self.a_scales_ptr = torch.empty(
                num_experts, device=w13_weight.device, dtype=torch.int64
            )
            self.b_scales_ptr = torch.empty(
                num_experts, device=w13_weight.device, dtype=torch.int64
            )
            self.expert_offsets = torch.empty(
                num_experts + 1, device=w13_weight.device, dtype=torch.int32
            )
            self.problem_sizes1 = torch.empty(
                num_experts, 3, device=w13_weight.device, dtype=torch.int32
            )
            self.problem_sizes2 = torch.empty(
                num_experts, 3, device=w13_weight.device, dtype=torch.int32
            )

    else:
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
            # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
            w13_weight_scale1 = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale1 = torch.nn.Parameter(
                torch.ones(num_experts, hidden_size, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale1", w13_weight_scale1)
            layer.register_parameter("w2_weight_scale1", w2_weight_scale1)

    # Add the quantization method used (per tensor/grouped/channel)
    # to ensure the weight scales are loaded in properly
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        if self.block_quant
        else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
    )
    # If loading fp8 checkpoint, pass the weight loaders.
    # If loading an fp16 checkpoint, do not (we will quantize in
    #   process_weights_after_loading()
    if self.quant_config.is_checkpoint_fp8_serialized:
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

    if _is_hip and _use_hip_int4:
        # The INT4 per-column scales are tagged as channel-wise.
        # NOTE(review): w13_weight_scale1 / w2_weight_scale1 are only created
        # in the non-block-quant HIP branch above - confirm INT4 is never
        # combined with block quantization.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale1, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale1, extra_weight_attrs)

    # INPUT_SCALES
    if self.quant_config.activation_scheme == "static":
        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8."
            )

        w13_input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

    else:
        # Dynamic activation: no input scale parameters are stored.
        layer.w13_input_scale = None
        layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
    """Finalize the MoE weights and scales after the checkpoint is loaded.

    Four paths, each ending in a return:
      1. HIP INT4: delegated to ``process_weights_hip_int4``.
      2. Block quantization: optional e4m3fnuz normalization (plus aiter
         weight shuffling) and optional AMX packing on CPU.
      3. FP16/BF16 checkpoint: each expert's w13 and w2 weights are
         quantized to FP8 with a freshly computed scale.
      4. FP8 checkpoint: input scales are reduced to their maximum, weights
         are optionally normalized to e4m3fnuz, and both w13 shards are
         requantized to the per-expert maximum weight scale.

    Raises:
        ValueError: If the activation scheme is static but an input scale is
            missing on the layer.
    """
    if _is_hip and _use_hip_int4:
        self.process_weights_hip_int4(layer)
        return

    # Block quant doesn't need to process weights after loading
    if self.block_quant:
        # If ROCm, normalize the weights and scales to e4m3fnuz
        if _is_fp8_fnuz:
            # activation_scheme: dynamic
            w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=layer.w13_weight,
                weight_scale=layer.w13_weight_scale_inv,
                input_scale=None,
            )
            w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=layer.w2_weight,
                weight_scale=layer.w2_weight_scale_inv,
                input_scale=None,
            )
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale_inv = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False
            )
            layer.w13_input_scale = None
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = torch.nn.Parameter(
                w2_weight_scale, requires_grad=False
            )
            layer.w2_input_scale = None

            if _use_aiter:
                # Pre-shuffle weights
                layer.w13_weight.data = shuffle_weight(
                    layer.w13_weight.contiguous(), (16, 16)
                )
                layer.w2_weight.data = shuffle_weight(
                    layer.w2_weight.contiguous(), (16, 16)
                )

        if _is_cpu:
            assert (
                _is_cpu_amx_available
            ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

        return

    # If checkpoint is fp16 or bfloat16, quantize in place.
    if not self.quant_config.is_checkpoint_fp8_serialized:
        # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

        # Re-initialize w13_scale because we directly quantize
        # merged w13 weights and generate a single scaling factor.
        layer.w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                layer.num_local_experts,
                dtype=torch.float32,
                device=w13_weight.device,
            ),
            requires_grad=False,
        )
        # Quantize one expert at a time, storing its scale next to it.
        for expert in range(layer.num_local_experts):
            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
            )
        layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

        if _is_hip:
            self.process_weights_hip_scale_padding(layer)
        return

    # If checkpoint is fp8, we need to handle that the
    # MoE kernels require single activation scale and single weight
    # scale for w13 per expert.
    else:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if layer.w13_input_scale is None or layer.w2_input_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                layer.w2_input_scale
            ):
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. "
                )
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        # If ROCm, normalize the weights and scales to e4m3fnuz
        if _is_fp8_fnuz:
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                )
            )
            w2_weight, w2_weight_scale, w2_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                )
            )
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False
            )
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False
                )
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(
                w2_weight_scale, requires_grad=False
            )
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False
                )
        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_local_experts):
            # Walk the two stacked shards of the merged w13 weight.
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start : start + shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id],
                )
                # Requantize in place with the shared per-expert scale; the
                # scale returned by scaled_fp8_quant is discarded.
                (
                    layer.w13_weight[expert_id][start : start + shard_size, :],
                    _,
                ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(
            max_w13_scales, requires_grad=False
        )

        if _is_hip:
            self.process_weights_hip_scale_padding(layer)
        return
def process_weights_hip_int4(self, layer: Module):
    """Post-process MoE weights for the INT4-FP8 path (INT4 weight, FP8 compute).

    Mutates ``layer`` in place:

    * ``w13_weight`` and ``w2_weight`` are replaced by the result of
      ``shuffle_weight(..., (16, 16))`` (weight permutation).
    * ``w13_weight_scale`` is collapsed to a single value per expert (the max
      over dim 1). The FP8 weights are not requantized; instead the INT4
      ``w13_weight_scale1`` rows of every shard whose scale differs from that
      max are multiplied by ``shard_scale / max``.
    * ``w13_weight_scale1`` and ``w2_weight_scale1`` are finally multiplied by
      the per-expert FP8 scale.

    Args:
        layer: MoE module holding the tensors named above, together with
            ``intermediate_size_per_partition`` and ``num_local_experts``.
    """
    # TODO: _use_aiter: add after triton kernel added
    # Weight permutation; the cache is emptied after each replacement.
    for weight_attr in ("w13_weight", "w2_weight"):
        setattr(
            layer,
            weight_attr,
            torch.nn.Parameter(
                shuffle_weight(getattr(layer, weight_attr).data, (16, 16)),
                requires_grad=False,
            ),
        )
        torch.cuda.empty_cache()

    # The fp8 MoE kernel needs one fp8 w13_weight_scale per expert. The fp8
    # weights are not directly available for requantization here, so the
    # INT4 w13_weight_scale1 entries of the affected shard are adjusted.
    assert layer.w13_weight_scale is not None
    rows_per_shard = layer.intermediate_size_per_partition
    per_expert_max = layer.w13_weight_scale.max(dim=1).values
    for eid in range(layer.num_local_experts):
        expert_max = per_expert_max[eid]
        for sid in range(2):
            shard_scale = layer.w13_weight_scale[eid][sid]
            if shard_scale != expert_max:
                row_begin = sid * rows_per_shard
                row_end = row_begin + rows_per_shard
                layer.w13_weight_scale1[eid][row_begin:row_end] *= (
                    shard_scale / expert_max
                )

    layer.w13_weight_scale = torch.nn.Parameter(per_expert_max, requires_grad=False)

    # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as
    # post GEMM scaling; the optimal design would apply per-column
    # weight_scale1 before GEMM and weight_scale after it.
    for eid in range(layer.num_local_experts):
        layer.w13_weight_scale1[eid] *= per_expert_max[eid]
        layer.w2_weight_scale1[eid] *= layer.w2_weight_scale[eid]
def process_weights_hip_scale_padding(self, layer: Module):
    """Apply ROCm-specific weight shuffling or padding to an MoE layer.

    Exactly one of the following happens, decided by module-level flags:

    * ``_use_aiter`` is set: both expert weights are replaced by their
      ``shuffle_weight(..., (16, 16))`` permutation and the ``*_weight_scale1``
      tensors are multiplied by the matching ``*_weight_scale`` (broadcast
      over a new trailing dim) for column-wise scaling.
    * otherwise, if the ``SGLANG_MOE_PADDING`` env flag is set: both expert
      weights are zero-padded by ``padding_size`` along their last dim.
    * otherwise: ``layer`` is left unchanged.

    Args:
        layer: MoE module whose weights/scales are rewritten in place.
    """
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
        padding_size,  # Avoid circular import
    )

    weight_attrs = ("w13_weight", "w2_weight")

    if _use_aiter:
        for attr in weight_attrs:
            setattr(
                layer,
                attr,
                torch.nn.Parameter(
                    shuffle_weight(getattr(layer, attr).data, (16, 16)),
                    requires_grad=False,
                ),
            )
            torch.cuda.empty_cache()
        # ROCm (_use_aiter): using column-wise scaling
        layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
        layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
        return

    if not get_bool_env_var("SGLANG_MOE_PADDING"):
        return

    # If ROCm, apply weight padding (min. Mem channel contention) only if set
    for attr in weight_attrs:
        setattr(
            layer,
            attr,
            torch.nn.Parameter(
                F.pad(getattr(layer, attr).data, (0, padding_size), "constant", 0),
                requires_grad=False,
            ),
        )
        torch.cuda.empty_cache()
def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Store the runner config and build a ``MoeRunner`` for supported backends.

    The configured backend is read via ``get_moe_runner_backend()``. When it
    is "auto", DEEP_GEMM is chosen if JIT DeepGEMM is enabled and the
    all-to-all backend is DeepEP; TRITON is chosen otherwise. ``self.runner``
    is only assigned for the DeepGEMM and Triton backends; for any other
    backend this method leaves it untouched.

    Args:
        layer: The MoE module (not used by this method).
        moe_runner_config: Configuration saved on ``self.moe_runner_config``
            and handed to the runner.
    """
    from sglang.srt.layers import deep_gemm_wrapper
    from sglang.srt.layers.moe.utils import (
        get_moe_a2a_backend,
        get_moe_runner_backend,
    )

    self.moe_runner_config = moe_runner_config

    backend = get_moe_runner_backend()
    if backend.is_auto():
        prefer_deep_gemm = (
            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            and get_moe_a2a_backend().is_deepep()
        )
        backend = (
            MoeRunnerBackend.DEEP_GEMM
            if prefer_deep_gemm
            else MoeRunnerBackend.TRITON
        )

    if not (backend.is_deep_gemm() or backend.is_triton()):
        # TODO(cwan): refactor other backends
        return

    self.runner = MoeRunner(backend, moe_runner_config)
def apply(
    self,
    layer: torch.nn.Module,
    dispatch_output: DispatchOutput,
) -> CombineInput:
    """Run the FP8 MoE experts on dispatched tokens.

    Paths are tried in this order, and the first one that applies returns:

    1. Intel AMX CPU backend (``fused_experts_cpu``).
    2. On HIP, ``maybe_apply_hip_fused_experts`` when it yields a result.
    3. CUTLASS fused experts when ``self.use_cutlass_fused_experts_fp8``.
    4. The ``MoeRunner`` built in ``create_moe_runner`` (DeepGEMM or Triton).

    Args:
        layer: MoE module providing the weights and scales.
        dispatch_output: Provides ``hidden_states`` and ``topk_output``.

    Returns:
        A ``StandardCombineInput`` for paths 1-3, or whatever
        ``self.runner.run`` returns for path 4.

    Raises:
        NotImplementedError: If the runner backend is neither DeepGEMM nor
            Triton.
    """
    from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

    x = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output
    moe_runner_config = self.moe_runner_config

    # ---- Path 1: Intel AMX on CPU -------------------------------------
    if use_intel_amx_backend(layer):
        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        topk_weights, topk_ids, _ = topk_output
        x, topk_weights = apply_topk_weights_cpu(
            moe_runner_config.apply_router_weight_on_input, topk_weights, x
        )
        cpu_output = torch.ops.sgl_kernel.fused_experts_cpu(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            False,  # inplace See [Note] inplace should be False in fused_experts.
            False,  # use_int8_w8a8
            True,  # use_fp8_w8a16
            layer.w13_weight_scale_inv,  # w1_scale
            layer.w2_weight_scale_inv,  # w2_scale
            self.quant_config.weight_block_size,  # block_size
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        return StandardCombineInput(hidden_states=cpu_output)

    # ---- Path 2: ROCm fused experts (may decline by returning None) ----
    if _is_hip:
        hip_output = self.maybe_apply_hip_fused_experts(
            layer,
            x,
            topk_output,
            moe_runner_config.activation,
            moe_runner_config.no_combine,
        )
        if hip_output is not None:
            return StandardCombineInput(hidden_states=hip_output)

    # ---- Path 3: CUTLASS block-scale FP8 -------------------------------
    if self.use_cutlass_fused_experts_fp8:
        from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

        topk_weights, topk_ids, _ = topk_output
        cutlass_output = cutlass_fused_experts_fp8(
            x,
            layer.w13_weight.transpose(1, 2),
            layer.w2_weight.transpose(1, 2),
            layer.w13_weight_scale_inv.transpose(1, 2),
            layer.w2_weight_scale_inv.transpose(1, 2),
            topk_weights,
            topk_ids,
            self.ab_strides1,
            self.c_strides1,
            self.ab_strides2,
            self.c_strides2,
            self.workspace,
            self.a_ptr,
            self.b_ptr,
            self.out_ptr,
            self.a_scales_ptr,
            self.b_scales_ptr,
            self.expert_offsets,
            self.problem_sizes1,
            self.problem_sizes2,
            use_fp8_blockscale=True,
        )
        return StandardCombineInput(hidden_states=cutlass_output)

    # ---- Path 4: MoeRunner (DeepGEMM / Triton) -------------------------
    backend = self.runner.runner_backend
    if backend.is_deep_gemm():
        w13_weight = layer.w13_weight
        w2_weight = layer.w2_weight

        if self.block_quant:
            block_shape = self.quant_config.weight_block_size
            w13_scale = layer.w13_weight_scale_inv
            w2_scale = layer.w2_weight_scale_inv
        else:
            # Convert per-tensor quant to per-block quant by repeating
            # scales for forward_deepgemm.
            scale_block_size = 128
            block_shape = [scale_block_size, scale_block_size]

            def _repeat_scale_per_block(scale, weight):
                # Ceil-divide weight dims 1 and 2 by the block size to get
                # the number of scale entries needed along each.
                blocks_n = (weight.shape[1] - 1) // scale_block_size + 1
                blocks_k = (weight.shape[2] - 1) // scale_block_size + 1
                expanded = scale.unsqueeze(1).repeat_interleave(blocks_n, dim=1)
                return expanded.unsqueeze(2).repeat_interleave(blocks_k, dim=2)

            w13_scale = _repeat_scale_per_block(layer.w13_weight_scale, w13_weight)
            w2_scale = _repeat_scale_per_block(layer.w2_weight_scale, w2_weight)

        quant_info = DeepGemmMoeQuantInfo(
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            use_fp8=True,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            block_shape=block_shape,
        )
    elif backend.is_triton():
        w13_weight = layer.w13_weight
        w2_weight = layer.w2_weight
        # Block-quantized checkpoints carry the *_scale_inv tensors.
        if self.block_quant:
            w13_scale = layer.w13_weight_scale_inv
            w2_scale = layer.w2_weight_scale_inv
        else:
            w13_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale

        quant_info = TritonMoeQuantInfo(
            w13_weight=w13_weight,
            w2_weight=w2_weight,
            use_fp8_w8a8=True,
            w13_scale=w13_scale,
            w2_scale=w2_scale,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            block_shape=self.quant_config.weight_block_size,
        )
    else:
        raise NotImplementedError("Unsupported runner backend: %s" % backend)

    return self.runner.run(dispatch_output, quant_info)
def apply_with_router_logits(
    self,
    layer: torch.nn.Module,
    dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
    """Run FlashInfer's ``trtllm_fp8_block_scale_moe`` from raw router logits.

    The hidden states are quantized per token group (group size taken from
    ``weight_block_size[1]``) and the kernel is handed the router logits so
    that routing happens inside it.

    Args:
        layer: MoE module providing block-quantized weights and
            ``*_weight_scale_inv`` tensors, plus expert-count attributes.
        dispatch_output: Provides ``hidden_states`` and a ``topk_output``
            in the "bypassed" format (router logits + top-k config).

    Returns:
        The tensor returned by ``trtllm_fp8_block_scale_moe``.

    Raises:
        AssertionError: If the top-k output is not in bypassed format, the
            activation is not ``"silu"``, or the expert-group settings are
            ``None``.
    """
    from flashinfer.fused_moe import trtllm_fp8_block_scale_moe

    from sglang.srt.layers.moe.topk import TopKOutputChecker

    x = dispatch_output.hidden_states
    topk_output = dispatch_output.topk_output
    runner_cfg = self.moe_runner_config
    activation = runner_cfg.activation
    routed_scaling_factor = runner_cfg.routed_scaling_factor

    assert TopKOutputChecker.format_is_bypassed(topk_output)
    router_logits = topk_output.router_logits
    topk_config = topk_output.topk_config

    assert (
        activation == "silu"
    ), "Only silu is supported for flashinfer blockscale fp8 moe"

    group_size = self.quant_config.weight_block_size[1]
    quantized_x, x_scales = per_token_group_quant_fp8(x, group_size)
    # NOTE: scales of hidden states have to be transposed!
    x_scales_t = x_scales.t().contiguous()

    assert (
        topk_config.num_expert_group is not None
        and topk_config.topk_group is not None
    ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"

    # The routing bias, when present, is cast to the activation dtype.
    if topk_config.correction_bias is None:
        routing_bias = None
    else:
        routing_bias = topk_config.correction_bias.to(x.dtype)

    # A missing scaling factor falls back to 1.0.
    if routed_scaling_factor is None:
        scaling = 1.0
    else:
        scaling = routed_scaling_factor

    return trtllm_fp8_block_scale_moe(
        routing_logits=router_logits.to(torch.float32),
        routing_bias=routing_bias,
        hidden_states=quantized_x,
        hidden_states_scale=x_scales_t,
        gemm1_weights=layer.w13_weight,
        gemm1_weights_scale=layer.w13_weight_scale_inv,
        gemm2_weights=layer.w2_weight,
        gemm2_weights_scale=layer.w2_weight_scale_inv,
        num_experts=layer.num_experts,
        top_k=topk_config.top_k,
        n_group=topk_config.num_expert_group,
        topk_group=topk_config.topk_group,
        intermediate_size=layer.w2_weight.shape[2],
        local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
        local_num_experts=layer.num_local_experts,
        routed_scaling_factor=scaling,
        tile_tokens_dim=get_tile_tokens_dim(
            x.shape[0], topk_config.top_k, layer.num_experts
        ),
        routing_method_type=2,  # DeepSeek-styled routing method
        use_shuffled_weight=False,
    )
def maybe_apply_hip_fused_experts(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    topk_output: TopKOutput,
    activation: str = "silu",
    no_combine: bool = False,
) -> Optional[torch.Tensor]:
    """Run the ROCm ``fused_moe`` kernel when an aiter-based path is enabled.

    * ``_use_hip_int4``: per-token quantization with the ``*_weight_scale1``
      tensors.
    * ``_use_aiter`` with block quantization: 128x128 block quantization with
      the ``*_weight_scale_inv`` tensors.
    * ``_use_aiter`` without block quantization: same call as the INT4 case.

    Args:
        layer: MoE module providing weights and scales.
        x: Input hidden states.
        topk_output: Unpacked as ``(topk_weights, topk_ids, _)``.
        activation: ``"silu"`` selects ``ActivationType.Silu``; any other
            value selects ``ActivationType.Gelu``.
        no_combine: Must be falsy on every kernel path.

    Returns:
        The ``fused_moe`` result, or ``None`` when neither flag is set so the
        caller can fall through to another implementation.
    """
    topk_weights, topk_ids, _ = topk_output

    if not (_use_hip_int4 or _use_aiter):
        return None

    # TODO: add triton kernel and add check _use_aiter (INT4 path)
    assert not no_combine, f"{no_combine=} is not supported."

    if activation == "silu":
        act_type = ActivationType.Silu
    else:
        act_type = ActivationType.Gelu

    # Only the plain aiter path (INT4 takes precedence) uses block scales.
    if not _use_hip_int4 and self.block_quant:
        return fused_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            w1_scale=layer.w13_weight_scale_inv,
            w2_scale=layer.w2_weight_scale_inv,
            quant_type=QuantType.per_128x128,
            activation=act_type,
            expert_mask=None,
        )

    return fused_moe(
        x,
        layer.w13_weight,
        layer.w2_weight,
        topk_weights,
        topk_ids,
        quant_type=QuantType.per_Token,
        w1_scale=layer.w13_weight_scale1,
        w2_scale=layer.w2_weight_scale1,
        activation=act_type,
    )
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.

    All behavior comes from ``BaseKVCacheMethod``; this subclass only
    forwards the FP8 quantization config to the base initializer.
    """

    def __init__(self, quant_config: Fp8Config):
        # Delegate entirely to the base class with the given config.
        super().__init__(quant_config)
Xet Storage Details
- Size:
- 53.4 kB
- Xet hash:
- 4ad72f2cd0f16026d38d65fd4c17319968eb133448de207996b476f06254231a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.