| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from fractions import Fraction | |
| from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union | |
| import torch | |
| from sglang.srt.layers.parameter import ( | |
| BasevLLMParameter, | |
| ChannelQuantScaleParameter, | |
| GroupQuantScaleParameter, | |
| PackedColumnParameter, | |
| PackedvLLMParameter, | |
| RowvLLMParameter, | |
| permute_param_layout_, | |
| ) | |
| from sglang.srt.layers.quantization.base_config import ( | |
| FusedMoEMethodBase, | |
| LinearMethodBase, | |
| QuantizationConfig, | |
| QuantizeMethodBase, | |
| ) | |
| from sglang.srt.layers.quantization.marlin_utils import ( | |
| apply_gptq_marlin_linear, | |
| check_marlin_supported, | |
| check_marlin_supports_shape, | |
| marlin_is_k_full, | |
| marlin_make_empty_g_idx, | |
| marlin_make_workspace, | |
| marlin_moe_permute_scales, | |
| marlin_permute_scales, | |
| marlin_repeat_scales_on_all_ranks, | |
| marlin_sort_g_idx, | |
| marlin_zero_points, | |
| verify_marlin_supported, | |
| ) | |
| from sglang.srt.layers.quantization.utils import ( | |
| get_linear_quant_method, | |
| get_scalar_types, | |
| replace_parameter, | |
| unpack_cols, | |
| ) | |
| if TYPE_CHECKING: | |
| from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig | |
| from sglang.srt.layers.moe.token_dispatcher import ( | |
| StandardDispatchOutput, | |
| CombineInput, | |
| ) | |
| from sglang.srt.utils import is_cuda | |
# Probe the platform once at import time; the GPTQ / Marlin kernels used
# below come from sgl_kernel and are only imported when running on CUDA.
_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import (
        fused_marlin_moe,
        gptq_gemm,
        gptq_marlin_repack,
        gptq_shuffle,
    )

logger = logging.getLogger(__name__)

# `ScalarType` is used in type annotations; `scalar_types` provides the
# concrete quantized weight types referenced by GPTQMarlinConfig.TYPE_MAP.
ScalarType, scalar_types = get_scalar_types()
def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    """Report whether a GPTQ quantization config declares Marlin packing.

    Two spellings are recognized:
      * ``checkpoint_format == "marlin"`` (gptqmodel and autogptq (eol) main);
      * the legacy ``is_marlin_format`` flag (autogptq <= 0.7.1).

    When the first spelling does not match, the raw ``is_marlin_format``
    value (default ``False``) is returned as-is.
    """
    if hf_quant_cfg.get("checkpoint_format") == "marlin":
        return True
    return hf_quant_cfg.get("is_marlin_format", False)
def gptq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack a stack of per-expert GPTQ weights into the Marlin layout.

    Each expert slice ``b_q_weight[e]`` is converted independently by
    ``gptq_marlin_repack`` using the matching permutation ``perm[e]``.

    Args:
        b_q_weight: Packed GPTQ weights; dim 0 indexes experts.
        perm: Per-expert permutation (g_idx sort indices); dim 0 indexes
            experts.
        size_k: Reduction (input) size of one expert's weight; must be a
            multiple of 16.
        size_n: Output size of one expert's weight.
        num_bits: Bits per quantized weight.

    Returns:
        A tensor of shape ``(num_experts, size_k // 16, size_n * (num_bits // 2))``
        with the dtype and device of ``b_q_weight``.
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0
    repacked = torch.empty(
        (expert_count, size_k // 16, size_n * (num_bits // 2)),
        dtype=b_q_weight.dtype,
        device=b_q_weight.device,
    )
    for expert in range(expert_count):
        repacked[expert] = gptq_marlin_repack(
            b_q_weight[expert], perm[expert], size_k, size_n, num_bits
        )
    return repacked
@dataclass
class MarlinLinearLayerConfig:
    """Shape and quantization metadata for one Marlin-backed linear layer.

    Built with keyword arguments by ``GPTQMarlinLinearMethod.create_weights``
    and read back when weights are post-processed and in ``apply``. It must be
    a dataclass; a plain class has no ``__init__`` accepting these fields.
    """

    # Unsharded weight shape as (in_features, out_features).
    full_weight_shape: tuple[int, int]  # [in, out]
    # This rank's shard of the weight, also (in_features, out_features).
    partition_weight_shape: tuple[int, int]
    # Quantized weight element type (a GPTQMarlinConfig.TYPE_MAP value here).
    weight_type: ScalarType
    # Activation dtype the layer was created with.
    act_type: torch.dtype
    # Quantization group size along in_features; -1 means a single group.
    group_size: int
    # Whether zero points are converted and passed to the kernel.
    zero_points: bool
    # Whether activation reordering (GPTQ desc_act / g_idx) is in use.
    has_g_idx: bool
class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        lm_head_quantized: bool,
        dynamic: Dict[str, Dict[str, Union[int, bool]]],
    ) -> None:
        """
        Args:
            weight_bits: Bits per quantized weight; one of 2, 3, 4 or 8.
            group_size: Quantization group size along the input dimension;
                -1 means one group spanning the whole input.
            desc_act: Whether the checkpoint uses activation reordering.
            lm_head_quantized: Whether the LM head is quantized.
            dynamic: Per-module overrides (format described below).

        Raises:
            ValueError: If ``weight_bits`` is not supported.
        """
        # GPTQModel uses the `dynamic` config property to allow per-module
        # quantization configs, so each module can be individually optimized.
        # The format is Dict[str, Dict]: each key is a regex string that
        # performs either a positive ("+:" prefixed) or a negative ("-:"
        # prefixed) match against a module; keys without a prefix default to
        # a positive match. For a positive match, the value dict holds field
        # overrides for the base quant config. A negative match skips
        # quantization init for the module entirely, i.e. the module runs
        # non-quantized inference. More details and quantization examples can
        # be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        #  # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
        #  # last 1/4 of the layers 16-21 has 8bit and group_size 64
        # dynamic = {
        #  #`.*\.` matches the layers_node prefix
        #  # positive match layer 10-15
        #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
        #  # positive match layer 16-21
        #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
        #  r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
        # }
        super().__init__()
        # Validate first so an unsupported width (e.g. 0) surfaces as this
        # ValueError rather than as an error while deriving pack_factor.
        if weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {weight_bits} bits."
            )
        self.dynamic = dynamic
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        # Quantized values per 32-bit word; a Fraction keeps 3-bit (32/3) exact.
        self.pack_factor = Fraction(32, self.weight_bits)

    def __repr__(self) -> str:
        return (
            f"GPTQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"dynamic={self.dynamic})"
        )

    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    # TODO: this minimum compute capability still needs to be verified.
    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
        """Build a GPTQConfig from a checkpoint's quantization config dict.

        ``bits``, ``group_size`` and ``desc_act`` are read with
        ``get_from_keys``; ``lm_head`` (default False) and ``dynamic``
        (default ``{}``, also when explicitly null) are optional.
        """
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[LinearMethodBase]:
        """Return the GPTQ linear method for ``layer``.

        Raises:
            TypeError: If ``layer`` is a FusedMoE layer (use gptq_marlin).
        """
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, FusedMoE):
            raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
        )
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin."""

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        is_sym: bool,
        lm_head_quantized: bool,
        dynamic: Dict[str, Dict[str, Union[int, bool]]],
        full_config: Dict[str, Any],
    ) -> None:
        """
        Args:
            weight_bits: Bits per quantized weight.
            group_size: Quantization group size; -1 means one group.
            desc_act: Whether the checkpoint uses activation reordering.
            is_sym: Whether quantization is symmetric.
            lm_head_quantized: Whether the LM head is quantized.
            dynamic: Per-module overrides (format described below).
            full_config: The raw quantization config dict.

        Raises:
            ValueError: If ``(weight_bits, is_sym)`` is not in ``TYPE_MAP``.
        """
        super().__init__()
        # Validate first so unsupported configs fail with this ValueError
        # before any derived values are computed.
        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
            )
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        # GPTQModel uses the `dynamic` config property to allow per-module
        # quantization configs, so each module can be individually optimized.
        # The format is Dict[str, Dict]: each key is a regex string that
        # performs either a positive ("+:" prefixed) or a negative ("-:"
        # prefixed) match against a module; keys without a prefix default to
        # a positive match. For a positive match, the value dict holds field
        # overrides for the base quant config. A negative match skips
        # quantization init for the module entirely, i.e. the module runs
        # non-quantized inference. More details and quantization examples can
        # be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        #  # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
        #  # last 1/4 of the layers 16-21 has 8bit and group_size 64
        # dynamic = {
        #  #`.*\.` matches the layers_node prefix
        #  # positive match layer 10-15
        #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
        #  # positive match layer 16-21
        #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
        #  r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
        # }
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.is_sym = is_sym
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config

        # (num_bits, is_sym) -> quant_type
        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        return (
            f"GPTQMarlinConfig(quant_type={self.quant_type}, "
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"dynamic={self.dynamic})"
        )

    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> GPTQMarlinConfig:
        """Build a GPTQMarlinConfig from a checkpoint's quantization config.

        ``bits``, ``group_size``, ``desc_act`` and ``sym`` are read with
        ``get_from_keys``; ``lm_head`` (default False) and ``dynamic``
        (default ``{}``, also when explicitly null) are optional. The whole
        dict is kept as ``full_config``.
        """
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(
            weight_bits,
            group_size,
            desc_act,
            is_sym,
            lm_head_quantized,
            dynamic,
            config,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        """Decide whether a plain GPTQ checkpoint should run on this kernel.

        Returns ``cls.get_name()`` when the checkpoint is not already in
        Marlin format, can be converted at runtime, and the user asked for
        nothing, ``"marlin"`` or ``"gptq_marlin"``. Otherwise returns None,
        logging a hint when the user explicitly forced ``"gptq"``.
        """
        is_marlin_format = check_marlin_format(hf_quant_cfg)
        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
        is_valid_user_quant = (
            user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
        )

        if not is_marlin_format and can_convert and is_valid_user_quant:
            name = cls.get_name()
            logger.info(
                "The model is convertible to %s during runtime. Using %s kernel.",
                name,
                name,
            )
            return name

        if not is_marlin_format and can_convert and user_quant == "gptq":
            logger.info(
                "Detected that the model can run with gptq_marlin"
                ", however you specified quantization=gptq explicitly,"
                " so forcing gptq. Use quantization=gptq_marlin for"
                " faster inference"
            )
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Return the Marlin MoE method for FusedMoE layers, else a linear method."""
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, FusedMoE):
            return GPTQMarlinMoEMethod(self)
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
        )

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        """Return whether a GPTQ checkpoint config can run on the Marlin kernel.

        Requires CUDA, ``quant_method == "gptq"`` (case-insensitive), and
        that ``bits``, ``group_size``, ``sym`` and ``desc_act`` are all
        present with a supported ``(bits, sym)`` pair.
        """
        if not _is_cuda:
            return False

        # `quant_method` may be missing or explicitly null in the config.
        quant_method = (quant_config.get("quant_method") or "").lower()
        if quant_method != "gptq":
            return False

        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        # Marlin conversion is only valid if required properties are found
        if num_bits is None or group_size is None or sym is None or desc_act is None:
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(
            quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
        )
class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Registers the packed GPTQ parameters on a linear layer and computes the
    forward pass with the ``gptq_gemm`` kernel (exllama-style shuffled
    weights when ``use_shuffle`` is set).

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register ``qweight``, ``g_idx``, ``qzeros`` and ``scales`` on ``layer``.

        Also sets ``self.use_shuffle``. The shuffle is disabled for a
        row-parallel shard (``input_size != input_size_per_partition``) of a
        grouped ``desc_act`` checkpoint, whose scales/zeros are kept unsharded.

        Raises:
            ValueError: If this shard's input size is not a multiple of the
                group size, or its output size is not aligned with the
                packing factor (e.g. a too-large tensor parallel size).
        """
        del output_size  # Unused.
        weight_loader = extra_weight_attrs.get("weight_loader")
        pack_factor = self.quant_config.pack_factor

        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % pack_factor.numerator != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        # group_size == -1 means a single group spanning the whole input.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        self.use_shuffle = True
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (
            input_size != input_size_per_partition
            and self.quant_config.group_size != -1
        ):
            if self.quant_config.desc_act:
                self.use_shuffle = False
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        # Default group index of each input row (row // group_size), which
        # may be replaced by weight_loader. Uses the normalized group size, so
        # group_size == -1 gives group 0 everywhere instead of negative indices.
        g_idx = RowvLLMParameter(
            data=torch.arange(input_size_per_partition, dtype=torch.int32)
            // group_size,
            input_dim=0,
            weight_loader=weight_loader,
        )

        qzeros_args = {
            "data": torch.empty(
                scale_and_zero_size,
                output_size_per_partition // pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }
        if scale_and_zero_input_dim is None:
            # Scales/zeros are not split along the input dim.
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=pack_factor,
                **qzeros_args,
            )
        else:
            # Scales/zeros are sharded along the input dim (dim 0).
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=pack_factor,
                **qzeros_args,
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap loaded tensors as plain Parameters and shuffle if enabled."""
        # for torch.compile
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        # exllama needs the weight shuffled after it is loaded; this is done
        # once, here. With act order, g_idx becomes the row sort order;
        # otherwise it is replaced by an empty tensor.
        if self.use_shuffle:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
            else:
                layer.g_idx.data = torch.empty(
                    (0,), dtype=torch.int, device=layer.g_idx.device
                )
            gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute ``x @ W (+ bias)`` with ``gptq_gemm``, keeping leading dims of ``x``."""
        out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
        reshaped_x = x.reshape(-1, x.shape[-1])

        output = gptq_gemm(
            reshaped_x,
            layer.qweight,
            layer.qzeros,
            layer.scales,
            layer.g_idx,
            self.use_shuffle,
            self.quant_config.weight_bits,
        )
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    _kernel_backends_being_used: set[str] = set()

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

        # Verify supported on platform.
        verify_marlin_supported(
            quant_type=self.quant_config.quant_type,
            group_size=self.quant_config.group_size,
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register ``qweight``, ``g_idx``, ``scales`` and ``qzeros`` on ``layer``.

        Also records the layer's shapes and quantization settings in
        ``self.kernel_config``, used after loading and in ``apply``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")
        pack_factor = self.quant_config.pack_factor

        self.kernel_config = MarlinLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act,
        )

        # Normalize group_size: -1 means a single group over the whole input.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(
            self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
        ):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each GPU in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=pack_factor,
            weight_loader=weight_loader,
        )

        # Activation order
        g_idx = RowvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            weight_loader=weight_loader,
        )

        qzeros_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition // pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader": weight_loader,
        }
        weight_scale_args = {
            "data": torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader": weight_loader,
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=pack_factor,
                **qzeros_args,
            )
        else:
            scales = GroupQuantScaleParameter(
                output_dim=1, input_dim=0, **weight_scale_args
            )
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=pack_factor,
                **qzeros_args,
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded GPTQ tensors into the layout Marlin expects.

        Sorts ``g_idx`` (act order) or replaces it with an empty tensor,
        repacks ``qweight`` with ``gptq_marlin_repack`` and permutes
        ``scales`` with ``marlin_permute_scales``. Also allocates the Marlin
        workspace and sets ``self.is_k_full``.

        Raises:
            ValueError: If the partition shape fails the Marlin shape check.
        """
        device = getattr(layer, "qweight").device
        c = self.kernel_config

        shape_check = check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )
        # The result used to be discarded, which made this check a no-op.
        # NOTE(review): expected to be a (supported, reason) pair as in vLLM's
        # marlin_utils; a bare boolean result is tolerated as well.
        if isinstance(shape_check, tuple):
            supported, reason = shape_check
        else:
            supported, reason = bool(shape_check), None
        if not supported:
            raise ValueError(
                f"Layer shape is not supported by the GPTQ Marlin kernel: {reason}"
            )

        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        self.w_q_name = "qweight"
        self.w_s_name = "scales"
        self.w_zp_name = "qzeros"
        self.w_gidx_name = "g_idx"

        def _transform_param(
            layer: torch.nn.Module, name: Optional[str], fn: Callable
        ) -> None:
            # Apply fn to an existing attribute and store the result back.
            if name is not None and getattr(layer, name, None) is not None:
                old_param = getattr(layer, name)
                new_param = fn(old_param)
                # replace the parameter with torch.nn.Parameter for TorchDynamo
                # compatibility
                replace_parameter(
                    layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
                )

        def transform_w_q(x):
            # Repack the quantized weight into Marlin's tile layout.
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
            )
            return x

        def transform_w_s(x):
            # Permute the scales to match the repacked weight.
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
            )
            return x

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            _transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

        # The weight repack reads layer.g_idx_sort_indices, set above.
        _transform_param(layer, self.w_q_name, transform_w_q)
        _transform_param(layer, self.w_s_name, transform_w_s)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Marlin GEMM for ``x`` against this layer's repacked weights."""
        c = self.kernel_config
        # `process_weights_after_loading` ensures the zero-point and g_idx
        # attributes exist (empty tensors when unused), as Marlin requires.
        return apply_gptq_marlin_linear(
            input=x,
            weight=getattr(layer, self.w_q_name),
            weight_scale=getattr(layer, self.w_s_name),
            weight_zp=getattr(layer, self.w_zp_name or "", None),  # type: ignore
            g_idx=getattr(layer, self.w_gidx_name or "", None),  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias,
        )
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    """MoE Marlin method with quantization.

    Stores per-expert GPTQ tensors for the fused gate/up projection (w13)
    and the down projection (w2), and runs ``fused_marlin_moe``.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register per-expert qweight, scales, qzeros, g_idx and sort indices.

        With act order (``desc_act``), the w2 scales and zero points are
        flagged ``load_full_w2`` and sized for the full (unsharded)
        intermediate dimension.
        """
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.linear import set_weight_attrs
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1

        if self.quant_config.group_size != -1:
            scales_size13 = hidden_size // self.quant_config.group_size
            # w2 is sharded along its input (intermediate) dim. With act order
            # its scales/zeros are loaded unsharded (load_full_w2 below), so
            # they must cover the full intermediate size. Otherwise each rank
            # holds only its own shard.
            if self.quant_config.desc_act:
                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
            else:
                w2_scales_size = intermediate_size_per_partition
            scales_size2 = w2_scales_size // self.quant_config.group_size
            strategy = FusedMoeWeightScaleSupported.GROUP.value
        else:
            scales_size13 = 1
            scales_size2 = 1
            strategy = FusedMoeWeightScaleSupported.CHANNEL.value

        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size // self.quant_config.pack_factor,
                2 * intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition // self.quant_config.pack_factor,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        # gate_up_proj scales
        w13_scales = torch.nn.Parameter(
            torch.empty(
                num_experts,
                scales_size13,
                2 * intermediate_size_per_partition,
                dtype=torch.half,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)
        # down_proj scales
        w2_scales = torch.nn.Parameter(
            torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)
        # dont shard the w2 scales when running act order
        set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
        # gate_up_proj zero points
        w13_qzeros = torch.nn.Parameter(
            torch.empty(
                num_experts,
                scales_size13,
                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)
        # down_proj zero points
        w2_qzeros = torch.nn.Parameter(
            torch.empty(
                num_experts,
                scales_size2,
                hidden_size // self.quant_config.pack_factor,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
        # dont shard the w2 zero points when running act order
        set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)
        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)
        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Sort/clear g_idx, then repack weights and permute scales for Marlin."""
        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            num_experts = layer.w13_g_idx.shape[0]
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
                    torch.int32
                )
                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
                    torch.int32
                )
                w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
        else:
            # Reset g_idx related tensors
            num_experts = layer.w13_g_idx.shape[0]
            device = layer.w13_g_idx.device
            layer.w13_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w2_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
        # Repack weights
        marlin_w13_qweight = gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
        # NOTE(review): the w13 weight repack above uses hidden_size as size_k,
        # but the scales below pass intermediate_size_per_partition; confirm
        # this is intended.
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1]
            * (
                self.quant_config.group_size
                if self.quant_config.group_size != -1
                else self.quant_config.pack_factor
            ),
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        """Store the runner config; ``apply`` reads its ``activation``."""
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Run ``fused_marlin_moe`` in float16 and cast back to the input dtype.

        Raises:
            AssertionError: If the configured activation is not SiLU.
        """
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        # The input must currently be float16
        orig_dtype = x.dtype
        x = x.half()

        topk_weights, topk_ids, router_logits = topk_output

        output = fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.quant_config.weight_bits,
            is_k_full=self.is_k_full,
        ).to(orig_dtype)
        return StandardCombineInput(hidden_states=output)
Xet Storage Details
- Size:
- 39.4 kB
- Xet hash:
- 4e4b7261db09d44fac49ee97362a6aa8cbe1572b18825d8cdb63f8b3bd42ab9d
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.