| # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py | |
| from __future__ import annotations | |
| import inspect | |
| from abc import ABC, abstractmethod | |
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type | |
| import torch | |
| from torch import nn | |
| if TYPE_CHECKING: | |
| from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig | |
| from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput | |
class QuantizeMethodBase(ABC):
    """Base class for different quantized methods.

    Subclasses must implement ``create_weights`` and ``apply``; both are
    declared abstract so an incomplete subclass fails at instantiation time
    instead of at the first forward call.
    """

    @abstractmethod
    def create_weights(
        self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
    ):
        """Create weights for a layer.

        The weights will be set as attributes of the layer.
        """
        raise NotImplementedError()

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError()

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Process the weight after loading.

        This can be used, for example, to transpose weights for computation.
        The base implementation does nothing; override it when post-load
        processing is needed.
        """
        return
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods.

    Both hooks are abstract: a concrete linear method must define how its
    weights are created and how they are applied to an input.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes forwarded by the caller to
                the implementation.
        """
        raise NotImplementedError()

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.

        Args:
            layer: The layer holding the weights created by create_weights.
            x: The input tensor.
            bias: Optional bias tensor supplied by the caller.

        Returns:
            The output tensor.
        """
        raise NotImplementedError()
class FusedMoEMethodBase(QuantizeMethodBase):
    """Base class for (maybe quantized) fused mixture-of-experts methods.

    All three hooks are abstract, so a concrete MoE method must implement
    weight creation, runner construction and the dispatch -> combine step.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the expert weights and set them as attributes of ``layer``.

        Args:
            layer: The MoE layer that will own the weights.
            num_experts: Number of experts to allocate weights for
                (presumably the count held by this rank -- confirm with callers).
            hidden_size: Hidden (model) dimension.
            intermediate_size_per_partition: Expert intermediate dimension on
                this partition.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes forwarded by the caller.
        """
        raise NotImplementedError

    @abstractmethod
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        """Set up the MoE runner for ``layer`` from ``moe_runner_config``.

        NOTE(review): where the runner is stored (e.g. on ``self`` or on
        ``layer``) is left to implementations -- confirm per subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: DispatchOutput,
    ) -> CombineInput:
        """Run the experts on ``dispatch_output`` and return the combine input.

        Args:
            layer: The MoE layer holding the weights from create_weights.
            dispatch_output: Output of the token dispatcher.

        Returns:
            The input expected by the token dispatcher's combine step.
        """
        raise NotImplementedError
class QuantizationConfig(ABC):
    """Base class for quantization configs.

    Instance hooks are plain abstract methods; hooks that are looked up on the
    class itself (``get_min_capability``, ``from_config``,
    ``override_quantization_method``) are classmethods, and pure helpers
    (``get_config_filenames``, ``get_from_keys``, ``get_from_keys_or``) are
    staticmethods so they work both on the class and on an instance.
    """

    def __init__(self):
        super().__init__()
        # mapping is updated by models as they initialize
        self.packed_modules_mapping: Dict[str, List[str]] = dict()

    @abstractmethod
    def get_name(self) -> str:
        """Name of the quantization method."""
        raise NotImplementedError()

    @abstractmethod
    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
        This requirement is due to the custom CUDA kernels used by the
        quantization method.
        """
        raise NotImplementedError()

    @staticmethod
    @abstractmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError()

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        """Optionally override the user-specified quantization method.

        Detects if this quantization method can support a given checkpoint
        format by overriding the user specified quantization method -- this
        method should only be overridden by subclasses in exceptional
        circumstances. The base implementation never overrides.

        Returns:
            The name of the method to use instead, or None for no override.
        """
        return None

    @classmethod
    def _modelopt_override_quantization_method(
        cls, hf_quant_config, user_quant
    ) -> Optional[str]:
        """Shared ModelOpt quantization method override logic.

        When the user asked for the generic ``"modelopt"`` method, pick the
        concrete ModelOpt variant from the checkpoint's ``quant_algo`` entry
        (case-insensitive).

        Args:
            hf_quant_config: The checkpoint's quantization config dict, or None.
            user_quant: The quantization method requested by the user.

        Returns:
            ``"modelopt_fp8"`` or ``"modelopt_fp4"`` when an override applies,
            otherwise None.
        """
        if hf_quant_config is None:
            return None
        # Only the generic "modelopt" request is auto-resolved; an explicit
        # method chosen by the user is left untouched.
        if user_quant != "modelopt":
            return None
        # `or ""` also covers an explicit `"quant_algo": null` entry, which
        # `.get("quant_algo", "")` would return as None and then crash on.
        quant_algo = (hf_quant_config.get("quant_algo") or "").upper()
        # FP8 is checked first, matching the original precedence.
        if "FP8" in quant_algo:
            return "modelopt_fp8"
        # "FP4" also matches "NVFP4".
        if "FP4" in quant_algo:
            return "modelopt_fp4"
        return None

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Get a value from the model's quantization config.

        Returns the value of the first key in ``keys`` present in ``config``.

        Raises:
            ValueError: If none of ``keys`` is present.
        """
        for key in keys:
            if key in config:
                return config[key]
        raise ValueError(
            f"Cannot find any of {keys} in the model's " "quantization config."
        )

    @staticmethod
    def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any:
        """Get an optional value from the model's quantization config.

        Like ``get_from_keys`` but returns ``default`` when no key is present.
        """
        try:
            return QuantizationConfig.get_from_keys(config, keys)
        except ValueError:
            return default

    @abstractmethod
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Get the quantize method to use for the quantized layer.

        Args:
            layer: The layer for the quant method.
            prefix: The full name of the layer in the state dict

        Returns:
            The quantize method. None if the given layer doesn't support quant
            method.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError()
def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
    """Report whether ``method_class`` provides its own ``embedding``.

    Embedding support is optional for quant methods. A class counts as
    implementing it only when an ``embedding`` attribute is found and it is
    not the very object ``QuantizeMethodBase`` itself exposes under that name.
    """
    # getattr_static performs the lookup without invoking descriptors,
    # __getattr__ or __getattribute__.
    candidate = inspect.getattr_static(method_class, "embedding", None)
    if candidate is None:
        return False
    inherited_default = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    return candidate is not inherited_default
Xet Storage Details
- Size:
- 7.9 kB
- Xet hash:
- e9a423d2e674d8c66a0c6965116971c884729cabdf8485d112b88dc8ca07516c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.